Skip to content

Commit cc254ef

Browse files
committed
Don't rely on endpoint params region
When generating the for the endpoint interceptor, don't assume that the endpoint params contains a 'region' member; endpoint rulesets are not required to have this parameter; in situations where the endpoint params does not have this member, compilation fails. For backwards compatibility, give the region endpoint params preference if it exists; otherwise use the region from the execution attributes.
1 parent 0626ab7 commit cc254ef

File tree

3 files changed

+208
-4
lines changed

3 files changed

+208
-4
lines changed

codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,22 @@ private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldNam
234234
AwsV4aAuthScheme.class, AwsV4aHttpSigner.class);
235235
b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()",
236236
AuthSchemeOption.Builder.class);
237-
b.addStatement("$T regionSet = $T.create(endpointParams.region().id())",
238-
RegionSet.class, RegionSet.class);
237+
238+
// Note: initially, we assumed that endpointParams contains a region() member, but endpoint rules does not require
239+
// this.
240+
//
241+
// For backwards compatibility reasons, we first check if the endpoint params has an explicit "region" parameter.
242+
// If so, use that. Note that a "region" ruleset param *may not* match the region set on the client.
243+
//
244+
// Otherwise, fallback to the client region.
245+
CodeBlock regionExpr;
246+
if (endpointRulesSpecUtils.isDeclaredParam("region")) {
247+
regionExpr = CodeBlock.of("endpointParams.region().id()");
248+
} else {
249+
regionExpr = CodeBlock.of("executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION).id()");
250+
}
251+
252+
b.addStatement("$T regionSet = $T.create($L)", RegionSet.class, RegionSet.class, regionExpr);
239253
b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class);
240254
b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), "
241255
+ "optionBuilder.build())", SelectedAuthScheme.class);

codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import org.junit.jupiter.api.Test;
2222
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
23+
import software.amazon.awssdk.codegen.model.rules.endpoints.ParameterModel;
2324
import software.amazon.awssdk.codegen.poet.ClassSpec;
2425
import software.amazon.awssdk.codegen.poet.ClientTestModels;
2526

@@ -44,11 +45,24 @@ private static IntermediateModel getModel(boolean useSraAuth) {
4445
}
4546

4647
@Test
47-
void endpointResolverInterceptorClassWithSigv4aMultiAuth() {
48-
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a());
48+
void endpointResolverInterceptorClassWithSigv4aMultiAuth_withRegionParameter() {
49+
IntermediateModel intermediateModel = ClientTestModels.opsWithSigv4a();
50+
51+
ParameterModel region = new ParameterModel();
52+
region.setType("string");
53+
intermediateModel.getEndpointRuleSetModel().getParameters().put("region", region);
54+
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(intermediateModel);
55+
4956
assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a.java"));
5057
}
5158

59+
@Test
60+
void endpointResolverInterceptorClassWithSigv4aMultiAuth_noRegionParameter() {
61+
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a());
62+
assertThat(endpointProviderInterceptor,
63+
generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a-noregionparam.java"));
64+
}
65+
5266
@Test
5367
void endpointResolverInterceptorClassWithEndpointBasedAuth() {
5468
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.queryServiceModelsEndpointAuthParamsWithoutAllowList());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package software.amazon.awssdk.services.database.endpoints.internal;
2+
3+
import java.time.Duration;
4+
import java.util.List;
5+
import java.util.Optional;
6+
import java.util.concurrent.CompletionException;
7+
import software.amazon.awssdk.annotations.Generated;
8+
import software.amazon.awssdk.annotations.SdkInternalApi;
9+
import software.amazon.awssdk.awscore.AwsExecutionAttribute;
10+
import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute;
11+
import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme;
12+
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme;
13+
import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme;
14+
import software.amazon.awssdk.core.SdkRequest;
15+
import software.amazon.awssdk.core.SelectedAuthScheme;
16+
import software.amazon.awssdk.core.exception.SdkClientException;
17+
import software.amazon.awssdk.core.interceptor.Context;
18+
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
19+
import software.amazon.awssdk.core.interceptor.ExecutionInterceptor;
20+
import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute;
21+
import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute;
22+
import software.amazon.awssdk.core.metrics.CoreMetric;
23+
import software.amazon.awssdk.endpoints.Endpoint;
24+
import software.amazon.awssdk.http.SdkHttpRequest;
25+
import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme;
26+
import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner;
27+
import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner;
28+
import software.amazon.awssdk.http.auth.aws.signer.RegionSet;
29+
import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption;
30+
import software.amazon.awssdk.identity.spi.Identity;
31+
import software.amazon.awssdk.metrics.MetricCollector;
32+
import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointParams;
33+
import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider;
34+
import software.amazon.awssdk.utils.CollectionUtils;
35+
36+
@Generated("software.amazon.awssdk:codegen")
37+
@SdkInternalApi
38+
public final class DatabaseResolveEndpointInterceptor implements ExecutionInterceptor {
39+
@Override
40+
public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttributes executionAttributes) {
41+
SdkRequest result = context.request();
42+
if (AwsEndpointProviderUtils.endpointIsDiscovered(executionAttributes)) {
43+
return result;
44+
}
45+
DatabaseEndpointProvider provider = (DatabaseEndpointProvider) executionAttributes
46+
.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER);
47+
try {
48+
long resolveEndpointStart = System.nanoTime();
49+
DatabaseEndpointParams endpointParams = ruleParams(result, executionAttributes);
50+
Endpoint endpoint = provider.resolveEndpoint(endpointParams).join();
51+
Duration resolveEndpointDuration = Duration.ofNanos(System.nanoTime() - resolveEndpointStart);
52+
Optional<MetricCollector> metricCollector = executionAttributes
53+
.getOptionalAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR);
54+
metricCollector.ifPresent(mc -> mc.reportMetric(CoreMetric.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration));
55+
if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) {
56+
Optional<String> hostPrefix = hostPrefix(executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME),
57+
result);
58+
if (hostPrefix.isPresent()) {
59+
endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get());
60+
}
61+
}
62+
List<EndpointAuthScheme> endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES);
63+
SelectedAuthScheme<?> selectedAuthScheme = executionAttributes
64+
.getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME);
65+
if (endpointAuthSchemes != null && selectedAuthScheme != null) {
66+
selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme);
67+
// Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications
68+
if (selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID)
69+
&& selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) == null) {
70+
AuthSchemeOption.Builder optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder();
71+
RegionSet regionSet = RegionSet.create(executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION)
72+
.id());
73+
optionBuilder.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet);
74+
selectedAuthScheme = new SelectedAuthScheme(selectedAuthScheme.identity(), selectedAuthScheme.signer(),
75+
optionBuilder.build());
76+
}
77+
executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme);
78+
}
79+
executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint);
80+
setMetricValues(endpoint, executionAttributes);
81+
return result;
82+
} catch (CompletionException e) {
83+
Throwable cause = e.getCause();
84+
if (cause instanceof SdkClientException) {
85+
throw (SdkClientException) cause;
86+
} else {
87+
throw SdkClientException.create("Endpoint resolution failed", cause);
88+
}
89+
}
90+
}
91+
92+
@Override
93+
public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) {
94+
Endpoint resolvedEndpoint = executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT);
95+
if (resolvedEndpoint.headers().isEmpty()) {
96+
return context.httpRequest();
97+
}
98+
SdkHttpRequest.Builder httpRequestBuilder = context.httpRequest().toBuilder();
99+
resolvedEndpoint.headers().forEach((name, values) -> {
100+
values.forEach(v -> httpRequestBuilder.appendHeader(name, v));
101+
});
102+
return httpRequestBuilder.build();
103+
}
104+
105+
public static DatabaseEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) {
106+
DatabaseEndpointParams.Builder builder = DatabaseEndpointParams.builder();
107+
builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes));
108+
builder.endpoint(AwsEndpointProviderUtils.endpointBuiltIn(executionAttributes));
109+
setContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request);
110+
setStaticContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME));
111+
setOperationContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request);
112+
return builder.build();
113+
}
114+
115+
private static void setContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) {
116+
}
117+
118+
private static void setStaticContextParams(DatabaseEndpointParams.Builder params, String operationName) {
119+
}
120+
121+
private <T extends Identity> SelectedAuthScheme<T> authSchemeWithEndpointSignerProperties(
122+
List<EndpointAuthScheme> endpointAuthSchemes, SelectedAuthScheme<T> selectedAuthScheme) {
123+
for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) {
124+
if (!endpointAuthScheme.schemeId().equals(selectedAuthScheme.authSchemeOption().schemeId())) {
125+
continue;
126+
}
127+
AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder();
128+
if (endpointAuthScheme instanceof SigV4AuthScheme) {
129+
SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme;
130+
if (v4AuthScheme.isDisableDoubleEncodingSet()) {
131+
option.putSignerProperty(AwsV4HttpSigner.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding());
132+
}
133+
if (v4AuthScheme.signingRegion() != null) {
134+
option.putSignerProperty(AwsV4HttpSigner.REGION_NAME, v4AuthScheme.signingRegion());
135+
}
136+
if (v4AuthScheme.signingName() != null) {
137+
option.putSignerProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, v4AuthScheme.signingName());
138+
}
139+
return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build());
140+
}
141+
if (endpointAuthScheme instanceof SigV4aAuthScheme) {
142+
SigV4aAuthScheme v4aAuthScheme = (SigV4aAuthScheme) endpointAuthScheme;
143+
if (v4aAuthScheme.isDisableDoubleEncodingSet()) {
144+
option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding());
145+
}
146+
if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme
147+
.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null)
148+
&& !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) {
149+
RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet());
150+
option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet);
151+
}
152+
if (v4aAuthScheme.signingName() != null) {
153+
option.putSignerProperty(AwsV4aHttpSigner.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName());
154+
}
155+
return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build());
156+
}
157+
throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name()
158+
+ "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?");
159+
}
160+
return selectedAuthScheme;
161+
}
162+
163+
private static void setOperationContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) {
164+
}
165+
166+
private static Optional<String> hostPrefix(String operationName, SdkRequest request) {
167+
return Optional.empty();
168+
}
169+
170+
private void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) {
171+
if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) {
172+
executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent(
173+
metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v)));
174+
}
175+
}
176+
}

0 commit comments

Comments
 (0)