@@ -15,9 +15,27 @@ import software.amazon.smithy.kotlin.codegen.rendering.protocol.ProtocolGenerato
1515import software.amazon.smithy.kotlin.codegen.rendering.serde.StructuredDataParserGenerator
1616import software.amazon.smithy.kotlin.codegen.rendering.serde.bodyDeserializer
1717import software.amazon.smithy.kotlin.codegen.rendering.serde.bodyDeserializerName
18+ import software.amazon.smithy.kotlin.codegen.utils.getOrNull
1819import software.amazon.smithy.model.shapes.*
1920import software.amazon.smithy.model.traits.EventHeaderTrait
2021import software.amazon.smithy.model.traits.EventPayloadTrait
22+ import software.amazon.smithy.model.traits.StreamingTrait
23+
24+ /* *
25+ * A set of RPC-bound Smithy protocols
26+ */
27+ val RPC_BOUND_PROTOCOLS = setOf (
28+ " awsJson1_0" ,
29+ " awsJson1_1" ,
30+ " awsQuery" ,
31+ " ec2Query" ,
32+ )
33+
34+ /* *
35+ * Represents whether the given ShapeId represents an RPC-bound Smithy protocol
36+ */
37+ internal val ShapeId .isRpcBoundProtocol: Boolean
38+ get() = RPC_BOUND_PROTOCOLS .contains(name)
2139
2240/* *
2341 * Implements rendering deserialize implementation for event streams implemented using the
@@ -60,26 +78,29 @@ class EventStreamParserGenerator(
6078 val streamShape = ctx.model.expectShape<UnionShape >(streamingMember.target)
6179 val streamSymbol = ctx.symbolProvider.toSymbol(streamShape)
6280
63- // TODO - handle RPC bound protocol bindings where the initial response is bound to an event stream document
64- // possibly by decoding the first Message
65-
6681 val messageTypeSymbol = RuntimeTypes .AwsEventStream .MessageType
6782 val baseExceptionSymbol = ExceptionBaseClassGenerator .baseExceptionSymbol(ctx.settings)
6883
6984 writer.write(" val chan = body.#T() ?: return" , RuntimeTypes .Http .toSdkByteReadChannel)
70- writer.write(" val events = #T(chan)" , RuntimeTypes .AwsEventStream .decodeFrames)
71- .indent()
85+ writer.write(" val frames = #T(chan)" , RuntimeTypes .AwsEventStream .decodeFrames)
86+ if (ctx.protocol.isRpcBoundProtocol) {
87+ renderDeserializeInitialResponse(ctx, output, writer)
88+ } else {
89+ writer.write(" val events = frames" )
90+ }
91+ writer.indent()
7292 .withBlock(" .#T { message ->" , " }" , RuntimeTypes .KotlinxCoroutines .Flow .map) {
73- withBlock(" when(val mt = message.#T()) {" , " }" , RuntimeTypes .AwsEventStream .MessageTypeExt ) {
74- withBlock(" is #T.Event -> when(mt.shapeType) {" , " }" , messageTypeSymbol) {
93+ withBlock(" when (val mt = message.#T()) {" , " }" , RuntimeTypes .AwsEventStream .MessageTypeExt ) {
94+ withBlock(" is #T.Event -> when (mt.shapeType) {" , " }" , messageTypeSymbol) {
7595 streamShape.filterEventStreamErrors(ctx.model).forEach { member ->
7696 withBlock(" #S -> {" , " }" , member.memberName) {
7797 renderDeserializeEventVariant(ctx, streamSymbol, member, writer)
7898 }
7999 }
100+
80101 write(" else -> #T.SdkUnknown" , streamSymbol)
81102 }
82- withBlock(" is #T.Exception -> when(mt.shapeType){" , " }" , messageTypeSymbol) {
103+ withBlock(" is #T.Exception -> when (mt.shapeType) {" , " }" , messageTypeSymbol) {
83104 // errors are completely bound to payload (at least according to design docs)
84105 val errorMembers = streamShape.members().filter {
85106 val target = ctx.model.expectShape(it.target)
@@ -101,7 +122,9 @@ class EventStreamParserGenerator(
101122 }
102123 }
103124 .dedent()
104- .write(" builder.#L = events" , streamingMember.defaultName())
125+ .write(" " )
126+
127+ writer.write(" builder.#L = events" , streamingMember.defaultName())
105128 }
106129
107130 private fun renderDeserializeEventVariant (ctx : ProtocolGenerator .GenerationContext , unionSymbol : Symbol , member : MemberShape , writer : KotlinWriter ) {
@@ -167,6 +190,48 @@ class EventStreamParserGenerator(
167190 writer.write(" #T.#L(e)" , unionSymbol, member.unionVariantName())
168191 }
169192
193+ /* *
194+ * Renders deserialization logic for a message with the `initial-response` type.
195+ */
196+ private fun renderDeserializeInitialResponse (ctx : ProtocolGenerator .GenerationContext , outputShape : StructureShape , writer : KotlinWriter ) {
197+ // A custom function which only deserializes the initial response members
198+ val initialResponseDeserializeFn = sdg.payloadDeserializer(ctx, outputShape, outputShape.initialResponseMembers)
199+
200+ writer.write(
201+ " val firstMessage = frames.#T(1).#T()" ,
202+ RuntimeTypes .KotlinxCoroutines .Flow .take,
203+ RuntimeTypes .KotlinxCoroutines .Flow .single,
204+ )
205+ writer.write(" val firstMessageType = firstMessage.type()" )
206+ writer.openBlock(
207+ " val events = if (firstMessageType is #T.Event && firstMessageType.shapeType == \" initial-response\" ) {" ,
208+ RuntimeTypes .AwsEventStream .MessageType ,
209+ )
210+ // Deserialize into `initialResponse`, then apply it to the actual response builder
211+ writer.write(" val initialResponse = #T(firstMessage.payload)" , initialResponseDeserializeFn)
212+ writer.withBlock(" builder.apply {" , " }" ) {
213+ outputShape.initialResponseMembers.forEach { member ->
214+ writer.write(" #1L = initialResponse.#1L" , member.defaultName())
215+ }
216+ }
217+ .write(" frames" )
218+ .closeAndOpenBlock(" } else {" )
219+ .write(
220+ " #T(#T(firstMessage), frames)" ,
221+ RuntimeTypes .KotlinxCoroutines .Flow .merge,
222+ RuntimeTypes .KotlinxCoroutines .Flow .flowOf,
223+ )
224+ .closeBlock(" }" )
225+ }
226+
227+ /* *
228+ * Get all the shape's members which aren't an event stream
229+ */
230+ private val StructureShape .initialResponseMembers get() = members().filter {
231+ val targetShape = ctx.model.getShape(it.target).getOrNull()
232+ targetShape?.hasTrait<StreamingTrait >() == false
233+ }
234+
170235 private fun renderDeserializeExplicitEventPayloadMember (
171236 ctx : ProtocolGenerator .GenerationContext ,
172237 member : MemberShape ,
0 commit comments