Skip to content

Commit a0c8a0a

Browse files
authored
Move QueryParametersToBodyInterceptor to after custom interceptors (#3547)
1 parent a29f80b commit a0c8a0a

File tree

3 files changed

+117
-6
lines changed

3 files changed

+117
-6
lines changed

codegen/src/main/java/software/amazon/awssdk/codegen/poet/builder/BaseClientBuilderClass.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import com.squareup.javapoet.TypeSpec;
2929
import com.squareup.javapoet.TypeVariableName;
3030
import java.util.ArrayList;
31+
import java.util.Collections;
3132
import java.util.List;
3233
import java.util.Map;
3334
import java.util.Optional;
@@ -271,10 +272,6 @@ private MethodSpec finalizeServiceConfigurationMethod() {
271272
ExecutionInterceptor.class),
272273
ArrayList.class);
273274

274-
if (model.getMetadata().isQueryProtocol()) {
275-
builder.addStatement("additionalInterceptors.add(new $T())", QueryParametersToBodyInterceptor.class);
276-
}
277-
278275
builder.addStatement("interceptors = $T.mergeLists(endpointInterceptors, interceptors)",
279276
CollectionUtils.class);
280277
builder.addStatement("interceptors = $T.mergeLists(interceptors, additionalInterceptors)",
@@ -283,6 +280,16 @@ private MethodSpec finalizeServiceConfigurationMethod() {
283280
builder.addCode("interceptors = $T.mergeLists(interceptors, config.option($T.EXECUTION_INTERCEPTORS));\n",
284281
CollectionUtils.class, SdkClientOption.class);
285282

283+
if (model.getMetadata().isQueryProtocol()) {
284+
TypeName listType = ParameterizedTypeName.get(List.class, ExecutionInterceptor.class);
285+
builder.addStatement("$T protocolInterceptors = $T.singletonList(new $T())",
286+
listType,
287+
Collections.class,
288+
QueryParametersToBodyInterceptor.class);
289+
builder.addStatement("interceptors = $T.mergeLists(interceptors, protocolInterceptors)",
290+
CollectionUtils.class);
291+
}
292+
286293
if (model.getEndpointOperation().isPresent()) {
287294
builder.beginControlFlow("if (!endpointDiscoveryEnabled)")
288295
.addStatement("$1T chain = new $1T(config)", DefaultEndpointDiscoveryProviderChain.class)

codegen/src/test/resources/software/amazon/awssdk/codegen/poet/builder/test-query-client-builder-class.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package software.amazon.awssdk.services.query;
22

33
import java.util.ArrayList;
4+
import java.util.Collections;
45
import java.util.List;
56
import software.amazon.awssdk.annotations.Generated;
67
import software.amazon.awssdk.annotations.SdkInternalApi;
@@ -50,10 +51,11 @@ protected final SdkClientConfiguration finalizeServiceConfiguration(SdkClientCon
5051
List<ExecutionInterceptor> interceptors = interceptorFactory
5152
.getInterceptors("software/amazon/awssdk/services/query/execution.interceptors");
5253
List<ExecutionInterceptor> additionalInterceptors = new ArrayList<>();
53-
additionalInterceptors.add(new QueryParametersToBodyInterceptor());
5454
interceptors = CollectionUtils.mergeLists(endpointInterceptors, interceptors);
5555
interceptors = CollectionUtils.mergeLists(interceptors, additionalInterceptors);
5656
interceptors = CollectionUtils.mergeLists(interceptors, config.option(SdkClientOption.EXECUTION_INTERCEPTORS));
57+
List<ExecutionInterceptor> protocolInterceptors = Collections.singletonList(new QueryParametersToBodyInterceptor());
58+
interceptors = CollectionUtils.mergeLists(interceptors, protocolInterceptors);
5759
return config.toBuilder().option(SdkClientOption.EXECUTION_INTERCEPTORS, interceptors)
5860
.option(SdkClientOption.CLIENT_CONTEXT_PARAMS, clientContextParams.build()).build();
5961
}
@@ -74,4 +76,4 @@ private SdkTokenProvider defaultTokenProvider() {
7476
private Signer defaultTokenSigner() {
7577
return BearerTokenSigner.create();
7678
}
77-
}
79+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://aws.amazon.com/apache2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package software.amazon.awssdk.services.protocolquery;
17+
18+
import static org.assertj.core.api.Assertions.assertThat;
19+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
20+
import static org.mockito.ArgumentMatchers.any;
21+
import static org.mockito.Mockito.atLeast;
22+
import static org.mockito.Mockito.mock;
23+
import static org.mockito.Mockito.verify;
24+
import static org.mockito.Mockito.when;
25+
26+
import java.io.IOException;
27+
import java.util.Optional;
28+
import org.junit.jupiter.api.AfterEach;
29+
import org.junit.jupiter.api.BeforeEach;
30+
import org.junit.jupiter.api.Test;
31+
import org.mockito.ArgumentCaptor;
32+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
33+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
34+
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
35+
import software.amazon.awssdk.core.exception.SdkClientException;
36+
import software.amazon.awssdk.core.interceptor.Context;
37+
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
38+
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
39+
import software.amazon.awssdk.http.ContentStreamProvider;
40+
import software.amazon.awssdk.http.ExecutableHttpRequest;
41+
import software.amazon.awssdk.http.HttpExecuteRequest;
42+
import software.amazon.awssdk.http.SdkHttpClient;
43+
import software.amazon.awssdk.http.SdkHttpRequest;
44+
import software.amazon.awssdk.regions.Region;
45+
import software.amazon.awssdk.utils.IoUtils;
46+
47+
public class MoveQueryParamsToBodyTest {
48+
private static final AwsCredentialsProvider CREDENTIALS = StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid"));
49+
50+
private SdkHttpClient mockHttpClient;
51+
52+
private ProtocolQueryClient client;
53+
54+
@BeforeEach
55+
public void setup() throws IOException {
56+
mockHttpClient = mock(SdkHttpClient.class);
57+
ExecutableHttpRequest mockRequest = mock(ExecutableHttpRequest.class);
58+
when(mockRequest.call()).thenThrow(new IOException("IO error!"));
59+
when(mockHttpClient.prepareRequest(any())).thenReturn(mockRequest);
60+
}
61+
62+
@AfterEach
63+
public void teardown() {
64+
if (client != null) {
65+
client.close();
66+
}
67+
client = null;
68+
}
69+
70+
@Test
71+
public void customInterceptor_additionalQueryParamsAdded_paramsAlsoMovedToBody() throws IOException {
72+
client = ProtocolQueryClient.builder()
73+
.overrideConfiguration(o -> o.addExecutionInterceptor(new AdditionalQueryParamInterceptor()))
74+
.region(Region.US_WEST_2)
75+
.credentialsProvider(CREDENTIALS)
76+
.httpClient(mockHttpClient)
77+
.build();
78+
79+
ArgumentCaptor<HttpExecuteRequest> requestCaptor = ArgumentCaptor.forClass(HttpExecuteRequest.class);
80+
81+
assertThatThrownBy(() -> client.membersInQueryParams(r -> r.stringQueryParam("hello")))
82+
.isInstanceOf(SdkClientException.class)
83+
.hasMessageContaining("IO");
84+
85+
verify(mockHttpClient, atLeast(1)).prepareRequest(requestCaptor.capture());
86+
87+
ContentStreamProvider requestContent = requestCaptor.getValue().contentStreamProvider().get();
88+
89+
String contentString = IoUtils.toUtf8String(requestContent.newStream());
90+
91+
assertThat(contentString).contains("CustomParamName=CustomParamValue");
92+
}
93+
94+
private static class AdditionalQueryParamInterceptor implements ExecutionInterceptor {
95+
@Override
96+
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
97+
return context.httpRequest().toBuilder()
98+
.putRawQueryParameter("CustomParamName", "CustomParamValue")
99+
.build();
100+
}
101+
}
102+
}

0 commit comments

Comments
 (0)