diff --git a/.changes/next-release/bugfix-AWSSDKforJavav2-d0c7f24.json b/.changes/next-release/bugfix-AWSSDKforJavav2-d0c7f24.json new file mode 100644 index 000000000000..342dac4f8ec5 --- /dev/null +++ b/.changes/next-release/bugfix-AWSSDKforJavav2-d0c7f24.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "AWS SDK for Java v2", + "contributor": "", + "description": "When generating the code for the endpoint interceptor, don't assume that the endpoint params contains a 'region' member; endpoint rulesets are not required to have this parameter; in situations where the endpoint params does not have this member, compilation fails.\n\nFor backwards compatibility, give the region endpoint params preference if it exists; otherwise use the region from the execution attributes." +} diff --git a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java index 7bac753e88e5..6e4c623384b9 100644 --- a/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java +++ b/codegen/src/main/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpec.java @@ -234,8 +234,22 @@ private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldNam AwsV4aAuthScheme.class, AwsV4aHttpSigner.class); b.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()", AuthSchemeOption.Builder.class); - b.addStatement("$T regionSet = $T.create(endpointParams.region().id())", - RegionSet.class, RegionSet.class); + + // Note: initially, we assumed that endpointParams contains a region() member, but endpoint rules does not require + // this. + // + // For backwards compatibility reasons, we first check if the endpoint params has an explicit "region" parameter. + // If so, use that. Note that a "region" ruleset param *may not* match the region set on the client. + // + // Otherwise, fallback to the client region. + CodeBlock regionExpr; + if (endpointRulesSpecUtils.isDeclaredParam("region")) { + regionExpr = CodeBlock.of("endpointParams.region().id()"); + } else { + regionExpr = CodeBlock.of("executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION).id()"); + } + + b.addStatement("$T regionSet = $T.create($L)", RegionSet.class, RegionSet.class, regionExpr); b.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class); b.addStatement("selectedAuthScheme = new $T(selectedAuthScheme.identity(), selectedAuthScheme.signer(), " + "optionBuilder.build())", SelectedAuthScheme.class); diff --git a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java index 77c5a44e3203..82874c1a2440 100644 --- a/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java +++ b/codegen/src/test/java/software/amazon/awssdk/codegen/poet/rules/EndpointResolverInterceptorSpecTest.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel; +import software.amazon.awssdk.codegen.model.rules.endpoints.ParameterModel; import software.amazon.awssdk.codegen.poet.ClassSpec; import software.amazon.awssdk.codegen.poet.ClientTestModels; @@ -44,11 +45,24 @@ private static IntermediateModel getModel(boolean useSraAuth) { } @Test - void endpointResolverInterceptorClassWithSigv4aMultiAuth() { - ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a()); + void endpointResolverInterceptorClassWithSigv4aMultiAuth_withRegionParameter() { + IntermediateModel intermediateModel = ClientTestModels.opsWithSigv4a(); + + ParameterModel region = new ParameterModel(); + region.setType("string"); + intermediateModel.getEndpointRuleSetModel().getParameters().put("region", region); + ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(intermediateModel); + assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a.java")); } + @Test + void endpointResolverInterceptorClassWithSigv4aMultiAuth_noRegionParameter() { + ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a()); + assertThat(endpointProviderInterceptor, + generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a-noregionparam.java")); + } + @Test void endpointResolverInterceptorClassWithEndpointBasedAuth() { ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.queryServiceModelsEndpointAuthParamsWithoutAllowList()); diff --git a/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolve-interceptor-with-multiauthsigv4a-noregionparam.java b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolve-interceptor-with-multiauthsigv4a-noregionparam.java new file mode 100644 index 000000000000..90c82e04a83d --- /dev/null +++ b/codegen/src/test/resources/software/amazon/awssdk/codegen/poet/rules/endpoint-resolve-interceptor-with-multiauthsigv4a-noregionparam.java @@ -0,0 +1,176 @@ +package software.amazon.awssdk.services.database.endpoints.internal; + +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletionException; +import software.amazon.awssdk.annotations.Generated; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.awscore.AwsExecutionAttribute; +import software.amazon.awssdk.awscore.endpoints.AwsEndpointAttribute; +import software.amazon.awssdk.awscore.endpoints.authscheme.EndpointAuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4AuthScheme; +import software.amazon.awssdk.awscore.endpoints.authscheme.SigV4aAuthScheme; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.SelectedAuthScheme; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.core.interceptor.SdkExecutionAttribute; +import software.amazon.awssdk.core.interceptor.SdkInternalExecutionAttribute; +import software.amazon.awssdk.core.metrics.CoreMetric; +import software.amazon.awssdk.endpoints.Endpoint; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.auth.aws.scheme.AwsV4aAuthScheme; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4HttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.AwsV4aHttpSigner; +import software.amazon.awssdk.http.auth.aws.signer.RegionSet; +import software.amazon.awssdk.http.auth.spi.scheme.AuthSchemeOption; +import software.amazon.awssdk.identity.spi.Identity; +import software.amazon.awssdk.metrics.MetricCollector; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointParams; +import software.amazon.awssdk.services.database.endpoints.DatabaseEndpointProvider; +import software.amazon.awssdk.utils.CollectionUtils; + +@Generated("software.amazon.awssdk:codegen") +@SdkInternalApi +public final class DatabaseResolveEndpointInterceptor implements ExecutionInterceptor { + @Override + public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttributes executionAttributes) { + SdkRequest result = context.request(); + if (AwsEndpointProviderUtils.endpointIsDiscovered(executionAttributes)) { + return result; + } + DatabaseEndpointProvider provider = (DatabaseEndpointProvider) executionAttributes + .getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER); + try { + long resolveEndpointStart = System.nanoTime(); + DatabaseEndpointParams endpointParams = ruleParams(result, executionAttributes); + Endpoint endpoint = provider.resolveEndpoint(endpointParams).join(); + Duration resolveEndpointDuration = Duration.ofNanos(System.nanoTime() - resolveEndpointStart); + Optional metricCollector = executionAttributes + .getOptionalAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR); + metricCollector.ifPresent(mc -> mc.reportMetric(CoreMetric.ENDPOINT_RESOLVE_DURATION, resolveEndpointDuration)); + if (!AwsEndpointProviderUtils.disableHostPrefixInjection(executionAttributes)) { + Optional hostPrefix = hostPrefix(executionAttributes.getAttribute(SdkExecutionAttribute.OPERATION_NAME), + result); + if (hostPrefix.isPresent()) { + endpoint = AwsEndpointProviderUtils.addHostPrefix(endpoint, hostPrefix.get()); + } + } + List endpointAuthSchemes = endpoint.attribute(AwsEndpointAttribute.AUTH_SCHEMES); + SelectedAuthScheme selectedAuthScheme = executionAttributes + .getAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME); + if (endpointAuthSchemes != null && selectedAuthScheme != null) { + selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme); + // Precedence of SigV4a RegionSet is set according to multi-auth SigV4a specifications + if (selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) + && selectedAuthScheme.authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) == null) { + AuthSchemeOption.Builder optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder(); + RegionSet regionSet = RegionSet.create(executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION) + .id()); + optionBuilder.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); + selectedAuthScheme = new SelectedAuthScheme(selectedAuthScheme.identity(), selectedAuthScheme.signer(), + optionBuilder.build()); + } + executionAttributes.putAttribute(SdkInternalExecutionAttribute.SELECTED_AUTH_SCHEME, selectedAuthScheme); + } + executionAttributes.putAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT, endpoint); + setMetricValues(endpoint, executionAttributes); + return result; + } catch (CompletionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SdkClientException) { + throw (SdkClientException) cause; + } else { + throw SdkClientException.create("Endpoint resolution failed", cause); + } + } + } + + @Override + public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { + Endpoint resolvedEndpoint = executionAttributes.getAttribute(SdkInternalExecutionAttribute.RESOLVED_ENDPOINT); + if (resolvedEndpoint.headers().isEmpty()) { + return context.httpRequest(); + } + SdkHttpRequest.Builder httpRequestBuilder = context.httpRequest().toBuilder(); + resolvedEndpoint.headers().forEach((name, values) -> { + values.forEach(v -> httpRequestBuilder.appendHeader(name, v)); + }); + return httpRequestBuilder.build(); + } + + public static DatabaseEndpointParams ruleParams(SdkRequest request, ExecutionAttributes executionAttributes) { + DatabaseEndpointParams.Builder builder = DatabaseEndpointParams.builder(); + builder.region(AwsEndpointProviderUtils.regionBuiltIn(executionAttributes)); + builder.endpoint(AwsEndpointProviderUtils.endpointBuiltIn(executionAttributes)); + setContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + setStaticContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME)); + setOperationContextParams(builder, executionAttributes.getAttribute(AwsExecutionAttribute.OPERATION_NAME), request); + return builder.build(); + } + + private static void setContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) { + } + + private static void setStaticContextParams(DatabaseEndpointParams.Builder params, String operationName) { + } + + private SelectedAuthScheme authSchemeWithEndpointSignerProperties( + List endpointAuthSchemes, SelectedAuthScheme selectedAuthScheme) { + for (EndpointAuthScheme endpointAuthScheme : endpointAuthSchemes) { + if (!endpointAuthScheme.schemeId().equals(selectedAuthScheme.authSchemeOption().schemeId())) { + continue; + } + AuthSchemeOption.Builder option = selectedAuthScheme.authSchemeOption().toBuilder(); + if (endpointAuthScheme instanceof SigV4AuthScheme) { + SigV4AuthScheme v4AuthScheme = (SigV4AuthScheme) endpointAuthScheme; + if (v4AuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4HttpSigner.DOUBLE_URL_ENCODE, !v4AuthScheme.disableDoubleEncoding()); + } + if (v4AuthScheme.signingRegion() != null) { + option.putSignerProperty(AwsV4HttpSigner.REGION_NAME, v4AuthScheme.signingRegion()); + } + if (v4AuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4HttpSigner.SERVICE_SIGNING_NAME, v4AuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + if (endpointAuthScheme instanceof SigV4aAuthScheme) { + SigV4aAuthScheme v4aAuthScheme = (SigV4aAuthScheme) endpointAuthScheme; + if (v4aAuthScheme.isDisableDoubleEncodingSet()) { + option.putSignerProperty(AwsV4aHttpSigner.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding()); + } + if (!(selectedAuthScheme.authSchemeOption().schemeId().equals(AwsV4aAuthScheme.SCHEME_ID) && selectedAuthScheme + .authSchemeOption().signerProperty(AwsV4aHttpSigner.REGION_SET) != null) + && !CollectionUtils.isNullOrEmpty(v4aAuthScheme.signingRegionSet())) { + RegionSet regionSet = RegionSet.create(v4aAuthScheme.signingRegionSet()); + option.putSignerProperty(AwsV4aHttpSigner.REGION_SET, regionSet); + } + if (v4aAuthScheme.signingName() != null) { + option.putSignerProperty(AwsV4aHttpSigner.SERVICE_SIGNING_NAME, v4aAuthScheme.signingName()); + } + return new SelectedAuthScheme<>(selectedAuthScheme.identity(), selectedAuthScheme.signer(), option.build()); + } + throw new IllegalArgumentException("Endpoint auth scheme '" + endpointAuthScheme.name() + + "' cannot be mapped to the SDK auth scheme. Was it declared in the service's model?"); + } + return selectedAuthScheme; + } + + private static void setOperationContextParams(DatabaseEndpointParams.Builder params, String operationName, SdkRequest request) { + } + + private static Optional hostPrefix(String operationName, SdkRequest request) { + return Optional.empty(); + } + + private void setMetricValues(Endpoint endpoint, ExecutionAttributes executionAttributes) { + if (endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES) != null) { + executionAttributes.getOptionalAttribute(SdkInternalExecutionAttribute.BUSINESS_METRICS).ifPresent( + metrics -> endpoint.attribute(AwsEndpointAttribute.METRIC_VALUES).forEach(v -> metrics.addMetric(v))); + } + } +}