diff --git a/.changes/20e287a5-cee0-4a5f-9d08-4022dcdee843.json b/.changes/20e287a5-cee0-4a5f-9d08-4022dcdee843.json new file mode 100644 index 00000000000..37442f4cdd2 --- /dev/null +++ b/.changes/20e287a5-cee0-4a5f-9d08-4022dcdee843.json @@ -0,0 +1,5 @@ +{ + "id": "20e287a5-cee0-4a5f-9d08-4022dcdee843", + "type": "misc", + "description": "Send x-amzn-query-mode=true for services with query-compatible trait" +} \ No newline at end of file diff --git a/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/AwsQueryModeCustomization.kt b/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/AwsQueryModeCustomization.kt new file mode 100644 index 00000000000..1d97a3c7a08 --- /dev/null +++ b/codegen/aws-sdk-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/AwsQueryModeCustomization.kt @@ -0,0 +1,28 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.codegen.customization + +import software.amazon.smithy.aws.traits.protocols.AwsQueryCompatibleTrait +import software.amazon.smithy.kotlin.codegen.KotlinSettings +import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +import software.amazon.smithy.kotlin.codegen.model.hasTrait +import software.amazon.smithy.kotlin.codegen.rendering.protocol.MutateHeadersMiddleware +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator +import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolMiddleware +import software.amazon.smithy.model.Model + +/** + * Send an extra `x-amzn-query-mode` header with a value of `true` for services which have the [AwsQueryCompatibleTrait] applied. + */ +class AwsQueryModeCustomization : KotlinIntegration { + override fun enabledForService(model: Model, settings: KotlinSettings): Boolean = + model + .getShape(settings.service) + .get() + .hasTrait() + + override fun customizeMiddleware(ctx: ProtocolGenerator.GenerationContext, resolved: List): List = + resolved + MutateHeadersMiddleware(extraHeaders = mapOf("x-amzn-query-mode" to "true")) +} diff --git a/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration b/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration index 3ef264034df..6d8199d3072 100644 --- a/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration +++ b/codegen/aws-sdk-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration @@ -43,3 +43,4 @@ aws.sdk.kotlin.codegen.customization.s3.express.S3ExpressIntegration aws.sdk.kotlin.codegen.customization.s3.S3ExpiresIntegration aws.sdk.kotlin.codegen.BusinessMetricsIntegration aws.sdk.kotlin.codegen.smoketests.SmokeTestsDenyListIntegration +aws.sdk.kotlin.codegen.customization.AwsQueryModeCustomization diff --git a/codegen/aws-sdk-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/AwsQueryModeCustomizationTest.kt b/codegen/aws-sdk-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/AwsQueryModeCustomizationTest.kt new file mode 100644 index 00000000000..b8d5e3e67e3 --- /dev/null +++ b/codegen/aws-sdk-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/AwsQueryModeCustomizationTest.kt @@ -0,0 +1,130 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0 + */ +package aws.sdk.kotlin.codegen.customization + +import aws.sdk.kotlin.codegen.testutil.lines +import software.amazon.smithy.kotlin.codegen.test.* +import kotlin.test.Test +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class AwsQueryModeCustomizationTest { + private val queryCompatibleModel = """ + namespace com.test + + use aws.protocols#awsJson1_0 + use aws.protocols#awsQueryCompatible + use aws.api#service + + @awsJson1_0 + @awsQueryCompatible + @service(sdkId: "QueryCompatible") + service QueryCompatible { + version: "1.0.0", + operations: [GetFoo] + } + + @http(method: "POST", uri: "/foo") + operation GetFoo { + input: GetFooInput + } + + structure GetFooInput { + payload: String + } + """ + .trimIndent() + .toSmithyModel() + + private val nonQueryCompatibleModel = """ + namespace com.test + + use aws.protocols#awsJson1_0 + use aws.protocols#awsQueryCompatible + use aws.api#service + + @awsJson1_0 + @service(sdkId: "NonQueryCompatible") + service NonQueryCompatible { + version: "1.0.0", + operations: [GetFoo] + } + + @http(method: "POST", uri: "/foo") + operation GetFoo { + input: GetFooInput + } + + structure GetFooInput { + payload: String + } + """ + .trimIndent() + .toSmithyModel() + + @Test + fun testEnabledForQueryCompatibleModel() { + assertTrue { + AwsQueryModeCustomization() + .enabledForService(queryCompatibleModel, queryCompatibleModel.defaultSettings()) + } + } + + @Test + fun testNotExpectedForNonQueryCompatibleModel() { + assertFalse { + AwsQueryModeCustomization() + .enabledForService(nonQueryCompatibleModel, nonQueryCompatibleModel.defaultSettings()) + } + } + + @Test + fun `x-amzn-query-mode applied`() { + val ctx = queryCompatibleModel.newTestContext("QueryCompatible", integrations = listOf(AwsQueryModeCustomization())) + val generator = MockHttpProtocolGenerator(queryCompatibleModel) + generator.generateProtocolClient(ctx.generationCtx) + + ctx.generationCtx.delegator.finalize() + ctx.generationCtx.delegator.flushWriters() + + val actual = ctx.manifest.expectFileString("/src/main/kotlin/com/test/DefaultTestClient.kt") + + val getFooMethod = actual.lines(" override suspend fun getFoo(input: GetFooRequest): GetFooResponse {", " }") + + val expectedHeaderMutation = """ + op.install( + MutateHeaders().apply { + append("x-amzn-query-mode", "true") + } + ) + """.replaceIndent(" ") + + getFooMethod.shouldContainOnlyOnceWithDiff(expectedHeaderMutation) + } + + @Test + fun `x-amzn-query-mode NOT applied`() { + val ctx = nonQueryCompatibleModel.newTestContext("NonQueryCompatible", integrations = listOf()) + val generator = MockHttpProtocolGenerator(nonQueryCompatibleModel) + generator.generateProtocolClient(ctx.generationCtx) + + ctx.generationCtx.delegator.finalize() + ctx.generationCtx.delegator.flushWriters() + + val actual = ctx.manifest.expectFileString("/src/main/kotlin/com/test/DefaultTestClient.kt") + + val getFooMethod = actual.lines(" override suspend fun getFoo(input: GetFooRequest): GetFooResponse {", " }") + + val unexpectedHeaderMutation = """ + op.install( + MutateHeaders().apply { + append("x-amzn-query-mode", "true") + } + ) + """.replaceIndent(" ") + + getFooMethod.shouldNotContainOnlyOnceWithDiff(unexpectedHeaderMutation) + } +}