Skip to content

Commit c7524fd

Browse files
authored
fix: handle upper/lowercase message field in error structs (#163)
1 parent 50c5921 commit c7524fd

File tree

3 files changed

+108
-2
lines changed

3 files changed

+108
-2
lines changed

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/RestXml.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package aws.sdk.kotlin.codegen.protocols
77

88
import aws.sdk.kotlin.codegen.protocols.core.AwsHttpBindingProtocolGenerator
99
import aws.sdk.kotlin.codegen.protocols.xml.RestXmlErrorMiddleware
10+
import aws.sdk.kotlin.codegen.protocols.xml.RestXmlSerdeDescriptorGenerator
1011
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
1112
import software.amazon.smithy.kotlin.codegen.core.*
1213
import software.amazon.smithy.kotlin.codegen.model.*
@@ -51,7 +52,7 @@ class RestXml : AwsHttpBindingProtocolGenerator() {
5152
val sortedMembers = sortMembersForSerialization(members)
5253

5354
// render the serde descriptors
54-
XmlSerdeDescriptorGenerator(ctx.toRenderingContext(this, shape, writer), sortedMembers).render()
55+
RestXmlSerdeDescriptorGenerator(ctx.toRenderingContext(this, shape, writer), sortedMembers).render()
5556
if (shape.isUnionShape) {
5657
SerializeUnionGenerator(ctx, sortedMembers, writer, defaultTimestampFormat).render()
5758
} else {
@@ -115,7 +116,7 @@ class RestXml : AwsHttpBindingProtocolGenerator() {
115116
members: List<MemberShape>,
116117
writer: KotlinWriter,
117118
) {
118-
XmlSerdeDescriptorGenerator(ctx.toRenderingContext(this, shape, writer), members).render()
119+
RestXmlSerdeDescriptorGenerator(ctx.toRenderingContext(this, shape, writer), members).render()
119120
if (shape.isUnionShape) {
120121
val name = ctx.symbolProvider.toSymbol(shape).name
121122
DeserializeUnionGenerator(ctx, name, members, writer, defaultTimestampFormat).render()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.codegen.protocols.xml
7+
8+
import software.amazon.smithy.kotlin.codegen.core.RenderingContext
9+
import software.amazon.smithy.kotlin.codegen.core.RuntimeTypes
10+
import software.amazon.smithy.kotlin.codegen.model.isError
11+
import software.amazon.smithy.kotlin.codegen.rendering.serde.SdkFieldDescriptorTrait
12+
import software.amazon.smithy.kotlin.codegen.rendering.serde.XmlSerdeDescriptorGenerator
13+
import software.amazon.smithy.kotlin.codegen.rendering.serde.add
14+
import software.amazon.smithy.kotlin.codegen.utils.dq
15+
import software.amazon.smithy.kotlin.codegen.utils.toggleFirstCharacterCase
16+
import software.amazon.smithy.model.shapes.MemberShape
17+
import software.amazon.smithy.model.shapes.Shape
18+
19+
/**
20+
* restXml-specific descriptor generator
21+
*/
22+
class RestXmlSerdeDescriptorGenerator(
23+
ctx: RenderingContext<Shape>,
24+
memberShapes: List<MemberShape>? = null
25+
) : XmlSerdeDescriptorGenerator(ctx, memberShapes) {
26+
override fun getFieldDescriptorTraits(
27+
member: MemberShape,
28+
targetShape: Shape,
29+
nameSuffix: String
30+
): List<SdkFieldDescriptorTrait> {
31+
val traitList = super.getFieldDescriptorTraits(member, targetShape, nameSuffix).toMutableList()
32+
33+
if (ctx.shape?.isError == true) {
34+
val serialName = getSerialName(member, nameSuffix)
35+
if (serialName.equals("message", ignoreCase = true)) {
36+
// Need to be able to read error messages from "Message" or "message"
37+
// https://github.com/awslabs/smithy-kotlin/issues/352
38+
traitList.add(RuntimeTypes.Serde.SerdeXml.XmlAliasName, serialName.toggleFirstCharacterCase().dq())
39+
}
40+
}
41+
42+
return traitList
43+
}
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0.
4+
*/
5+
6+
package aws.sdk.kotlin.codegen.protocols.xml
7+
8+
import org.junit.jupiter.api.Test
9+
import software.amazon.smithy.kotlin.codegen.test.*
10+
import software.amazon.smithy.model.shapes.ShapeId
11+
12+
class RestXmlSerdeDescriptorGeneratorTest {
13+
private fun render(modelSnippet: String): String {
14+
val model = modelSnippet.prependNamespaceAndService().toSmithyModel()
15+
16+
val testCtx = model.newTestContext()
17+
val writer = testCtx.newWriter()
18+
val shape = model.expectShape(ShapeId.from("com.test#Foo"))
19+
val renderingCtx = testCtx.toRenderingContext(writer, shape)
20+
21+
RestXmlSerdeDescriptorGenerator(renderingCtx).render()
22+
return writer.toString()
23+
}
24+
25+
@Test
26+
fun `it should add alias for message field in error struct`() {
27+
val generated = render(
28+
"""
29+
@error("client")
30+
structure Foo {
31+
message: String,
32+
foo: String
33+
}
34+
""".trimIndent()
35+
)
36+
37+
val expectedDescriptors = """
38+
val FOO_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("foo"))
39+
val MESSAGE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("message"), XmlAliasName("Message"))
40+
""".formatForTest("")
41+
42+
generated.shouldContainOnlyOnceWithDiff(expectedDescriptors)
43+
}
44+
45+
@Test
46+
fun `it should not add alias for message field in non-error struct`() {
47+
val generated = render(
48+
"""
49+
structure Foo {
50+
message: String
51+
}
52+
""".trimIndent()
53+
)
54+
55+
val expectedDescriptors = """
56+
val MESSAGE_DESCRIPTOR = SdkFieldDescriptor(SerialKind.String, XmlSerialName("message"))
57+
""".formatForTest("")
58+
59+
generated.shouldContainOnlyOnceWithDiff(expectedDescriptors)
60+
}
61+
}

0 commit comments

Comments
 (0)