Skip to content

Commit 177f282

Browse files
authored
feat: support initial-request and initial-response event stream types for RPC-bound protocols (#1026)
1 parent 988f392 commit 177f282

File tree

10 files changed

+331
-136
lines changed

10 files changed

+331
-136
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{
2+
"id": "d6c229e5-12a0-4d32-97d9-a7368a4d71dd",
3+
"type": "feature",
4+
"description": "Support initial-request and initial-response for event streams using RPC-based protocols"
5+
}

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperations.kt

Lines changed: 0 additions & 63 deletions
This file was deleted.

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamParserGenerator.kt

Lines changed: 74 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,27 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerato
1515
import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataParserGenerator
1616
import software.amazon.smithy.kotlin.codegen.rendering.serde.bodyDeserializer
1717
import software.amazon.smithy.kotlin.codegen.rendering.serde.bodyDeserializerName
18+
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
1819
import software.amazon.smithy.model.shapes.*
1920
import software.amazon.smithy.model.traits.EventHeaderTrait
2021
import software.amazon.smithy.model.traits.EventPayloadTrait
22+
import software.amazon.smithy.model.traits.StreamingTrait
23+
24+
/**
25+
* A set of RPC-bound Smithy protocols
26+
*/
27+
val RPC_BOUND_PROTOCOLS = setOf(
28+
"awsJson1_0",
29+
"awsJson1_1",
30+
"awsQuery",
31+
"ec2Query",
32+
)
33+
34+
/**
35+
* Represents whether the given ShapeId represents an RPC-bound Smithy protocol
36+
*/
37+
internal val ShapeId.isRpcBoundProtocol: Boolean
38+
get() = RPC_BOUND_PROTOCOLS.contains(name)
2139

2240
/**
2341
* Implements rendering deserialize implementation for event streams implemented using the
@@ -60,26 +78,29 @@ class EventStreamParserGenerator(
6078
val streamShape = ctx.model.expectShape<UnionShape>(streamingMember.target)
6179
val streamSymbol = ctx.symbolProvider.toSymbol(streamShape)
6280

63-
// TODO - handle RPC bound protocol bindings where the initial response is bound to an event stream document
64-
// possibly by decoding the first Message
65-
6681
val messageTypeSymbol = RuntimeTypes.AwsEventStream.MessageType
6782
val baseExceptionSymbol = ExceptionBaseClassGenerator.baseExceptionSymbol(ctx.settings)
6883

6984
writer.write("val chan = body.#T() ?: return", RuntimeTypes.Http.toSdkByteReadChannel)
70-
writer.write("val events = #T(chan)", RuntimeTypes.AwsEventStream.decodeFrames)
71-
.indent()
85+
writer.write("val frames = #T(chan)", RuntimeTypes.AwsEventStream.decodeFrames)
86+
if (ctx.protocol.isRpcBoundProtocol) {
87+
renderDeserializeInitialResponse(ctx, output, writer)
88+
} else {
89+
writer.write("val events = frames")
90+
}
91+
writer.indent()
7292
.withBlock(".#T { message ->", "}", RuntimeTypes.KotlinxCoroutines.Flow.map) {
73-
withBlock("when(val mt = message.#T()) {", "}", RuntimeTypes.AwsEventStream.MessageTypeExt) {
74-
withBlock("is #T.Event -> when(mt.shapeType) {", "}", messageTypeSymbol) {
93+
withBlock("when (val mt = message.#T()) {", "}", RuntimeTypes.AwsEventStream.MessageTypeExt) {
94+
withBlock("is #T.Event -> when (mt.shapeType) {", "}", messageTypeSymbol) {
7595
streamShape.filterEventStreamErrors(ctx.model).forEach { member ->
7696
withBlock("#S -> {", "}", member.memberName) {
7797
renderDeserializeEventVariant(ctx, streamSymbol, member, writer)
7898
}
7999
}
100+
80101
write("else -> #T.SdkUnknown", streamSymbol)
81102
}
82-
withBlock("is #T.Exception -> when(mt.shapeType){", "}", messageTypeSymbol) {
103+
withBlock("is #T.Exception -> when (mt.shapeType) {", "}", messageTypeSymbol) {
83104
// errors are completely bound to payload (at least according to design docs)
84105
val errorMembers = streamShape.members().filter {
85106
val target = ctx.model.expectShape(it.target)
@@ -101,7 +122,9 @@ class EventStreamParserGenerator(
101122
}
102123
}
103124
.dedent()
104-
.write("builder.#L = events", streamingMember.defaultName())
125+
.write("")
126+
127+
writer.write("builder.#L = events", streamingMember.defaultName())
105128
}
106129

107130
private fun renderDeserializeEventVariant(ctx: ProtocolGenerator.GenerationContext, unionSymbol: Symbol, member: MemberShape, writer: KotlinWriter) {
@@ -167,6 +190,48 @@ class EventStreamParserGenerator(
167190
writer.write("#T.#L(e)", unionSymbol, member.unionVariantName())
168191
}
169192

193+
/**
194+
* Renders deserialization logic for a message with the `initial-response` type.
195+
*/
196+
private fun renderDeserializeInitialResponse(ctx: ProtocolGenerator.GenerationContext, outputShape: StructureShape, writer: KotlinWriter) {
197+
// A custom function which only deserializes the initial response members
198+
val initialResponseDeserializeFn = sdg.payloadDeserializer(ctx, outputShape, outputShape.initialResponseMembers)
199+
200+
writer.write(
201+
"val firstMessage = frames.#T(1).#T()",
202+
RuntimeTypes.KotlinxCoroutines.Flow.take,
203+
RuntimeTypes.KotlinxCoroutines.Flow.single,
204+
)
205+
writer.write("val firstMessageType = firstMessage.type()")
206+
writer.openBlock(
207+
"val events = if (firstMessageType is #T.Event && firstMessageType.shapeType == \"initial-response\") {",
208+
RuntimeTypes.AwsEventStream.MessageType,
209+
)
210+
// Deserialize into `initialResponse`, then apply it to the actual response builder
211+
writer.write("val initialResponse = #T(firstMessage.payload)", initialResponseDeserializeFn)
212+
writer.withBlock("builder.apply {", "}") {
213+
outputShape.initialResponseMembers.forEach { member ->
214+
writer.write("#1L = initialResponse.#1L", member.defaultName())
215+
}
216+
}
217+
.write("frames")
218+
.closeAndOpenBlock("} else {")
219+
.write(
220+
"#T(#T(firstMessage), frames)",
221+
RuntimeTypes.KotlinxCoroutines.Flow.merge,
222+
RuntimeTypes.KotlinxCoroutines.Flow.flowOf,
223+
)
224+
.closeBlock("}")
225+
}
226+
227+
/**
228+
* Get all the shape's members which aren't an event stream
229+
*/
230+
private val StructureShape.initialResponseMembers get() = members().filter {
231+
val targetShape = ctx.model.getShape(it.target).getOrNull()
232+
targetShape?.hasTrait<StreamingTrait>() == false
233+
}
234+
170235
private fun renderDeserializeExplicitEventPayloadMember(
171236
ctx: ProtocolGenerator.GenerationContext,
172237
member: MemberShape,

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamSerializerGenerator.kt

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ import software.amazon.smithy.kotlin.codegen.lang.KotlinTypes
1212
import software.amazon.smithy.kotlin.codegen.model.*
1313
import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerator
1414
import software.amazon.smithy.kotlin.codegen.rendering.serde.*
15+
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
1516
import software.amazon.smithy.model.shapes.*
1617
import software.amazon.smithy.model.traits.EventHeaderTrait
1718
import software.amazon.smithy.model.traits.EventPayloadTrait
19+
import software.amazon.smithy.model.traits.StreamingTrait
1820

1921
/**
2022
* Implements rendering serialize implementation for event streams implemented using the
@@ -76,10 +78,35 @@ class EventStreamSerializerGenerator(
7678
)
7779

7880
val encodeFn = encodeEventStreamMessage(ctx, op, streamShape)
79-
writer.withBlock("val messages = stream", "") {
80-
write(".#T(::#T)", RuntimeTypes.KotlinxCoroutines.Flow.map, encodeFn)
81-
write(".#T(context)", RuntimeTypes.AwsEventStream.sign)
82-
write(".#T()", RuntimeTypes.AwsEventStream.encode)
81+
82+
writer.write("")
83+
val initialRequestMembers = input.initialRequestMembers(ctx)
84+
if (ctx.protocol.isRpcBoundProtocol && initialRequestMembers.isNotEmpty()) {
85+
val serializerFn = sdg.payloadSerializer(ctx, input, initialRequestMembers)
86+
87+
writer.withBlock("val initialRequest = buildMessage {", "}") {
88+
writer.write("addHeader(\":message-type\", HeaderValue.String(\"event\"))")
89+
writer.write("addHeader(\":event-type\", HeaderValue.String(\"initial-request\"))")
90+
writer.write("payload = #T(input)", serializerFn)
91+
}
92+
93+
writer.withBlock(
94+
"val messages = #T(#T(initialRequest), stream.#T(::#T))",
95+
"",
96+
RuntimeTypes.KotlinxCoroutines.Flow.merge,
97+
RuntimeTypes.KotlinxCoroutines.Flow.flowOf,
98+
RuntimeTypes.KotlinxCoroutines.Flow.map,
99+
encodeFn,
100+
) {
101+
write(".#T(context)", RuntimeTypes.AwsEventStream.sign)
102+
write(".#T()", RuntimeTypes.AwsEventStream.encode)
103+
}
104+
} else {
105+
writer.withBlock("val messages = stream", "") {
106+
write(".#T(::#T)", RuntimeTypes.KotlinxCoroutines.Flow.map, encodeFn)
107+
write(".#T(context)", RuntimeTypes.AwsEventStream.sign)
108+
write(".#T()", RuntimeTypes.AwsEventStream.encode)
109+
}
83110
}
84111

85112
writer.write("")
@@ -195,4 +222,12 @@ class EventStreamSerializerGenerator(
195222
private fun KotlinWriter.addStringHeader(name: String, value: String) {
196223
write("addHeader(#S, #T.String(#S))", name, RuntimeTypes.AwsEventStream.HeaderValue, value)
197224
}
225+
226+
/**
227+
* Get all the shape's members which aren't an event stream
228+
*/
229+
private fun StructureShape.initialRequestMembers(ctx: ProtocolGenerator.GenerationContext) = members().filter {
230+
val targetShape = ctx.model.getShape(it.target).getOrNull()
231+
targetShape?.hasTrait<StreamingTrait>() == false
232+
}
198233
}

codegen/smithy-aws-kotlin-codegen/src/main/resources/META-INF/services/software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ aws.sdk.kotlin.codegen.customization.BoxServices
1616
aws.sdk.kotlin.codegen.customization.glacier.GlacierBodyChecksum
1717
aws.sdk.kotlin.codegen.customization.machinelearning.MachineLearningEndpointCustomization
1818
aws.sdk.kotlin.codegen.customization.BackfillOptionalAuth
19-
aws.sdk.kotlin.codegen.customization.RemoveEventStreamOperations
2019
aws.sdk.kotlin.codegen.customization.flexiblechecksums.FlexibleChecksumsRequest
2120
aws.sdk.kotlin.codegen.customization.flexiblechecksums.FlexibleChecksumsResponse
2221
aws.sdk.kotlin.codegen.customization.route53.TrimResourcePrefix

codegen/smithy-aws-kotlin-codegen/src/test/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperationsTest.kt

Lines changed: 0 additions & 46 deletions
This file was deleted.

tests/codegen/event-stream/build.gradle.kts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,27 @@ dependencies {
1919
data class EventStreamTest(
2020
val projectionName: String,
2121
val protocolName: String,
22+
val modelTemplate: File,
2223
) {
2324
val model: File
2425
get() = buildDir.resolve("$projectionName/model.smithy")
2526
}
2627

2728
val tests = listOf(
28-
EventStreamTest("restJson1", "restJson1"),
29+
EventStreamTest("restJson1", "restJson1", file("event-stream-model-template.smithy")),
30+
EventStreamTest("awsJson11", "awsJson1_1", file("event-stream-initial-request-response.smithy")),
2931
)
3032

31-
fun fillInModel(output: File, protocolName: String) {
32-
val template = file("event-stream-model-template.smithy")
33+
fun fillInModel(output: File, protocolName: String, template: File) {
3334
val input = template.readText()
3435
val opTraits = when (protocolName) {
3536
"restJson1", "restXml" -> """@http(method: "POST", uri: "/test-eventstream", code: 200)"""
3637
else -> ""
3738
}
38-
val replaced = input.replace("{protocol-name}", protocolName)
39+
val replaced = input
40+
.replace("{protocol-name}", protocolName)
3941
.replace("{op-traits}", opTraits)
42+
4043
output.parentFile.mkdirs()
4144
output.writeText(replaced)
4245
}
@@ -60,7 +63,7 @@ codegen {
6063

6164
smithyKotlinPlugin {
6265
serviceShapeId = testServiceShapeId
63-
packageName = "aws.sdk.kotlin.test.eventstream.${test.protocolName.toLowerCase()}"
66+
packageName = "aws.sdk.kotlin.test.eventstream.${test.projectionName.toLowerCase()}"
6467
packageVersion = "1.0"
6568
buildSettings {
6669
generateFullProject = false
@@ -77,7 +80,7 @@ codegen {
7780

7881
tasks.named("generateSmithyBuildConfig") {
7982
doFirst {
80-
tests.forEach { test -> fillInModel(test.model, test.protocolName) }
83+
tests.forEach { test -> fillInModel(test.model, test.protocolName, test.modelTemplate) }
8184
}
8285
}
8386

0 commit comments

Comments
 (0)