Skip to content

Commit 258c634

Browse files
authored
misc: add Kinesis SubscribeToShard E2E test (#1030)
1 parent 0b5ce1f commit 258c634

File tree

7 files changed

+246
-21
lines changed

7 files changed

+246
-21
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "316c80e0-c526-4a42-8fc0-165aebd7581b",
3+
"type": "bugfix",
4+
"description": "Fix closing an event stream causing an IllegalStateException",
5+
"issues": [
6+
"https://github.com/awslabs/smithy-kotlin/issues/935"
7+
]
8+
}

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/protocols/eventstream/EventStreamParserGenerator.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ class EventStreamParserGenerator(
6262
// we just need to deserialize the event stream member (and/or the initial response)
6363
writer.withBlock(
6464
// FIXME - revert to private, exposed as internal temporarily while we figure out integration tests
65-
"internal suspend fun #L(builder: #T.Builder, body: #T) {",
65+
"internal suspend fun #L(builder: #T.Builder, call: #T) {",
6666
"}",
6767
op.bodyDeserializerName(),
6868
outputSymbol,
69-
RuntimeTypes.Http.HttpBody,
69+
RuntimeTypes.Http.HttpCall,
7070
) {
7171
renderDeserializeEventStream(ctx, op, writer)
7272
}
@@ -81,7 +81,7 @@ class EventStreamParserGenerator(
8181
val messageTypeSymbol = RuntimeTypes.AwsEventStream.MessageType
8282
val baseExceptionSymbol = ExceptionBaseClassGenerator.baseExceptionSymbol(ctx.settings)
8383

84-
writer.write("val chan = body.#T() ?: return", RuntimeTypes.Http.toSdkByteReadChannel)
84+
writer.write("val chan = call.response.body.#T(call) ?: return", RuntimeTypes.Http.toSdkByteReadChannel)
8585
writer.write("val frames = #T(chan)", RuntimeTypes.AwsEventStream.decodeFrames)
8686
if (ctx.protocol.isRpcBoundProtocol) {
8787
renderDeserializeInitialResponse(ctx, output, writer)
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.sdk.kotlin.services.kinesis
6+
7+
import aws.sdk.kotlin.services.kinesis.model.*
8+
import aws.sdk.kotlin.services.kinesis.waiters.waitUntilStreamExists
9+
import aws.sdk.kotlin.testing.withAllEngines
10+
import aws.smithy.kotlin.runtime.retries.getOrThrow
11+
import kotlinx.coroutines.*
12+
import kotlinx.coroutines.flow.first
13+
import org.junit.jupiter.api.AfterAll
14+
import org.junit.jupiter.api.BeforeAll
15+
import org.junit.jupiter.api.TestInstance
16+
import java.util.*
17+
import kotlin.test.Test
18+
import kotlin.test.assertEquals
19+
import kotlin.time.Duration.Companion.seconds
20+
21+
/**
22+
* Tests for Kinesis SubscribeToShard (an RPC-bound protocol)
23+
*/
24+
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
25+
class KinesisSubscribeToShardTest {
26+
private val client = KinesisClient { region = "us-east-1" }
27+
private val WAIT_TIMEOUT = 30.seconds
28+
private val POLLING_RATE = 3.seconds
29+
30+
private val STREAM_NAME_PREFIX = "aws-sdk-kotlin-e2e-test-stream-"
31+
private val STREAM_CONSUMER_NAME_PREFIX = "aws-sdk-kotlin-e2e-test-"
32+
33+
private val TEST_DATA = "Bees, bees, bees, bees!"
34+
35+
private lateinit var dataStreamArn: String
36+
private lateinit var dataStreamConsumerArn: String
37+
38+
/**
39+
* Create infrastructure required for the test, if it doesn't exist already.
40+
*/
41+
@BeforeAll
42+
fun setup(): Unit = runBlocking {
43+
dataStreamArn = client.getOrCreateStream()
44+
dataStreamConsumerArn = client.getOrRegisterStreamConsumer()
45+
}
46+
47+
/**
48+
* Delete infrastructure used for the test.
49+
*/
50+
@AfterAll
51+
fun cleanUp(): Unit = runBlocking {
52+
client.deregisterStreamConsumer {
53+
streamArn = dataStreamArn
54+
consumerArn = dataStreamConsumerArn
55+
}
56+
57+
client.deleteStream {
58+
streamArn = dataStreamArn
59+
}
60+
}
61+
62+
/**
63+
* Select the single shard ID associated with the data stream, and subscribe to it.
64+
* Read one event and make sure the data matches what's expected.
65+
*/
66+
@Test
67+
fun testSubscribeToShard(): Unit = runBlocking {
68+
val dataStreamShardId = client.listShards {
69+
streamArn = dataStreamArn
70+
}.shards?.single()!!.shardId
71+
72+
withAllEngines { engine ->
73+
client.withConfig {
74+
httpClient = engine
75+
}.use { clientWithTestEngine ->
76+
clientWithTestEngine.subscribeToShard(
77+
SubscribeToShardRequest {
78+
consumerArn = dataStreamConsumerArn
79+
shardId = dataStreamShardId
80+
startingPosition = StartingPosition {
81+
type = ShardIteratorType.TrimHorizon
82+
}
83+
},
84+
) {
85+
val event = it.eventStream?.first()
86+
val record = event?.asSubscribeToShardEvent()?.records?.single()
87+
assertEquals(TEST_DATA, record?.data?.decodeToString())
88+
}
89+
90+
// Wait 5 seconds, otherwise a ResourceInUseException gets thrown. Source:
91+
// https://docs.aws.amazon.com/kinesis/latest/APIReference/API_SubscribeToShard.html
92+
// > If you call SubscribeToShard 5 seconds or more after a successful call, the second call takes over the subscription
93+
delay(5.seconds)
94+
}
95+
}
96+
}
97+
98+
/**
99+
* Get a Kinesis data stream with the [STREAM_NAME_PREFIX], or if one does not exist,
100+
* create one and populate it with one test record.
101+
* @return the ARN of the data stream
102+
*/
103+
private suspend fun KinesisClient.getOrCreateStream(): String =
104+
listStreams { }
105+
.streamSummaries
106+
?.find { it.streamName?.startsWith(STREAM_NAME_PREFIX) ?: false }
107+
?.streamArn ?: run {
108+
// Create a new data stream, then wait for it to be active
109+
val randomStreamName = STREAM_NAME_PREFIX + UUID.randomUUID()
110+
createStream {
111+
streamName = randomStreamName
112+
shardCount = 1
113+
}
114+
115+
val newStreamArn = waitUntilStreamExists({ streamName = randomStreamName })
116+
.getOrThrow()
117+
.streamDescription!!
118+
.streamArn!!
119+
120+
// Put a record, then wait for it to appear on the stream
121+
putRecord {
122+
data = TEST_DATA.encodeToByteArray()
123+
streamArn = newStreamArn
124+
partitionKey = "Goodbye"
125+
}
126+
127+
val newStreamShardId = client.listShards {
128+
streamArn = newStreamArn
129+
}.shards?.single()!!.shardId
130+
131+
val currentShardIterator = getShardIterator {
132+
shardId = newStreamShardId
133+
shardIteratorType = ShardIteratorType.TrimHorizon
134+
streamArn = newStreamArn
135+
}.shardIterator!!
136+
137+
waitForResource {
138+
getRecords {
139+
shardIterator = currentShardIterator
140+
streamArn = newStreamArn
141+
}.records
142+
?.firstOrNull { it.data?.decodeToString() == TEST_DATA }
143+
}
144+
145+
newStreamArn
146+
}
147+
148+
/**
149+
* Get a Kinesis data stream consumer, or if it doesn't exist, register a new one.
150+
* @return the ARN of the stream consumer
151+
*/
152+
private suspend fun KinesisClient.getOrRegisterStreamConsumer(): String =
153+
listStreamConsumers { streamArn = dataStreamArn }
154+
.consumers
155+
?.firstOrNull { it.consumerName?.startsWith(STREAM_CONSUMER_NAME_PREFIX) ?: false }
156+
?.consumerArn ?: run {
157+
// Register a new consumer and wait for it to be active
158+
159+
val randomConsumerName = STREAM_CONSUMER_NAME_PREFIX + UUID.randomUUID()
160+
registerStreamConsumer {
161+
consumerName = randomConsumerName
162+
streamArn = dataStreamArn
163+
}
164+
165+
waitForResource {
166+
listStreamConsumers { streamArn = dataStreamArn }
167+
?.consumers
168+
?.firstOrNull { it.consumerName == randomConsumerName }
169+
?.takeIf { it.consumerStatus == ConsumerStatus.Active }
170+
?.consumerArn
171+
}
172+
}
173+
174+
/**
175+
* Poll at a predefined [POLLING_RATE] for a resource to exist and return it.
176+
* Throws an exception if this takes longer than the [WAIT_TIMEOUT] duration.
177+
*
178+
* @param getResource a suspending function which returns the resource or null if it does not exist yet
179+
* @return the resource
180+
*/
181+
private suspend fun <T> KinesisClient.waitForResource(getResource: suspend () -> T?): T = withTimeout(WAIT_TIMEOUT) {
182+
var resource: T? = null
183+
while (resource == null) {
184+
resource = getResource()
185+
resource ?: run {
186+
delay(POLLING_RATE)
187+
yield()
188+
}
189+
}
190+
return@withTimeout resource
191+
}
192+
}

services/route53/common/test/aws/sdk/kotlin/services/route53/internal/ChangeResourceRecordSetsUnmarshallingTest.kt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ class ChangeResourceRecordSetsUnmarshallingTest {
3333
""".trimIndent()
3434

3535
val response: HttpResponse = HttpResponse(
36-
HttpStatusCode(400, "Bad Request"),
37-
Headers.invoke { },
36+
HttpStatusCode.BadRequest,
37+
Headers.Empty,
3838
HttpBody.fromBytes(bodyText.encodeToByteArray()),
3939
)
4040

@@ -62,8 +62,8 @@ class ChangeResourceRecordSetsUnmarshallingTest {
6262
""".trimIndent()
6363

6464
val response: HttpResponse = HttpResponse(
65-
HttpStatusCode(400, "Bad Request"),
66-
Headers.invoke { },
65+
HttpStatusCode.BadRequest,
66+
Headers.Empty,
6767
HttpBody.fromBytes(bodyText.encodeToByteArray()),
6868
)
6969

@@ -92,8 +92,8 @@ class ChangeResourceRecordSetsUnmarshallingTest {
9292
""".trimIndent()
9393

9494
val response: HttpResponse = HttpResponse(
95-
HttpStatusCode(400, "Bad Request"),
96-
Headers.invoke { },
95+
HttpStatusCode.BadRequest,
96+
Headers.Empty,
9797
HttpBody.fromBytes(bodyText.encodeToByteArray()),
9898
)
9999

@@ -123,8 +123,8 @@ class ChangeResourceRecordSetsUnmarshallingTest {
123123
""".trimIndent()
124124

125125
val response: HttpResponse = HttpResponse(
126-
HttpStatusCode(400, "Bad Request"),
127-
Headers.invoke { },
126+
HttpStatusCode.BadRequest,
127+
Headers.Empty,
128128
HttpBody.fromBytes(bodyText.encodeToByteArray()),
129129
)
130130

services/s3/common/test/aws/sdk/kotlin/services/s3/internal/GetBucketLocationOperationDeserializerTest.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class GetBucketLocationOperationDeserializerTest {
3030
""".trimIndent()
3131

3232
val response: HttpResponse = HttpResponse(
33-
HttpStatusCode(200, "Success"),
34-
Headers.invoke { },
33+
HttpStatusCode.OK,
34+
Headers.Empty,
3535
HttpBody.fromBytes(responseXML.encodeToByteArray()),
3636
)
3737

@@ -55,8 +55,8 @@ class GetBucketLocationOperationDeserializerTest {
5555
""".trimIndent()
5656

5757
val response: HttpResponse = HttpResponse(
58-
HttpStatusCode(400, "Bad Request"),
59-
Headers.invoke { },
58+
HttpStatusCode.BadRequest,
59+
Headers.Empty,
6060
HttpBody.fromBytes(responseXML.encodeToByteArray()),
6161
)
6262

tests/codegen/event-stream/src/test/kotlin/HttpEventStreamTests.kt

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@ import aws.smithy.kotlin.runtime.auth.awssigning.AwsSigningAttributes
1313
import aws.smithy.kotlin.runtime.auth.awssigning.DefaultAwsSigner
1414
import aws.smithy.kotlin.runtime.auth.awssigning.HashSpecification
1515
import aws.smithy.kotlin.runtime.awsprotocol.eventstream.*
16+
import aws.smithy.kotlin.runtime.http.Headers
1617
import aws.smithy.kotlin.runtime.http.HttpBody
17-
import aws.smithy.kotlin.runtime.http.content.ByteArrayContent
18+
import aws.smithy.kotlin.runtime.http.HttpCall
19+
import aws.smithy.kotlin.runtime.http.HttpStatusCode
20+
import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder
21+
import aws.smithy.kotlin.runtime.http.response.HttpResponse
1822
import aws.smithy.kotlin.runtime.io.SdkBuffer
1923
import aws.smithy.kotlin.runtime.operation.ExecutionContext
2024
import aws.smithy.kotlin.runtime.smithy.test.assertJsonStringsEqual
@@ -69,10 +73,17 @@ class HttpEventStreamTests {
6973
private suspend fun deserializedEvent(message: Message): TestStream {
7074
val buffer = SdkBuffer()
7175
message.encode(buffer)
72-
val body = ByteArrayContent(buffer.readByteArray())
76+
77+
val response = HttpResponse(
78+
HttpStatusCode.OK,
79+
Headers.Empty,
80+
HttpBody.fromBytes(buffer.readByteArray()),
81+
)
82+
val call = HttpCall(HttpRequestBuilder().build(), response)
83+
7384
val builder = TestStreamOpResponse.Builder()
7485

75-
deserializeTestStreamOpOperationBody(builder, body)
86+
deserializeTestStreamOpOperationBody(builder, call)
7687

7788
val resp = builder.build()
7889
checkNotNull(resp.value)

tests/codegen/event-stream/src/test/kotlin/RpcEventStreamTests.kt

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@ import aws.smithy.kotlin.runtime.auth.awssigning.AwsSigningAttributes
1515
import aws.smithy.kotlin.runtime.auth.awssigning.DefaultAwsSigner
1616
import aws.smithy.kotlin.runtime.auth.awssigning.HashSpecification
1717
import aws.smithy.kotlin.runtime.awsprotocol.eventstream.*
18+
import aws.smithy.kotlin.runtime.http.Headers
1819
import aws.smithy.kotlin.runtime.http.HttpBody
20+
import aws.smithy.kotlin.runtime.http.HttpCall
21+
import aws.smithy.kotlin.runtime.http.HttpStatusCode
22+
import aws.smithy.kotlin.runtime.http.request.HttpRequestBuilder
23+
import aws.smithy.kotlin.runtime.http.response.HttpResponse
1924
import aws.smithy.kotlin.runtime.io.SdkBuffer
2025
import aws.smithy.kotlin.runtime.operation.ExecutionContext
2126
import aws.smithy.kotlin.runtime.util.get
@@ -90,9 +95,9 @@ class RpcEventStreamTests {
9095
val responseBody = flowOf(initialResponseMessage, eventStreamResponse)
9196
.encode()
9297
.asEventStreamHttpBody(this)
93-
9498
val builder = TestStreamOperationWithInitialRequestResponseResponse.Builder()
95-
deserializeTestStreamOperationWithInitialRequestResponseOperationBody(builder, responseBody)
99+
100+
deserializeTestStreamOperationWithInitialRequestResponseOperationBody(builder, responseBody.asHttpCall())
96101

97102
assertEquals(builder.initial, initialResponseData)
98103
val event = builder.value?.single() // this throws an exception if there's not exactly 1 event
@@ -118,7 +123,7 @@ class RpcEventStreamTests {
118123
.asEventStreamHttpBody(this)
119124

120125
val builder = TestStreamOperationWithInitialRequestResponseResponse.Builder()
121-
deserializeTestStreamOperationWithInitialRequestResponseOperationBody(builder, responseBody)
126+
deserializeTestStreamOperationWithInitialRequestResponseOperationBody(builder, responseBody.asHttpCall())
122127

123128
assertNull(builder.initial)
124129
val event = builder.value?.single()
@@ -159,4 +164,13 @@ class RpcEventStreamTests {
159164

160165
return frames
161166
}
167+
168+
private fun HttpBody.asHttpCall(): HttpCall {
169+
val response = HttpResponse(
170+
HttpStatusCode.OK,
171+
Headers.Empty,
172+
this,
173+
)
174+
return HttpCall(HttpRequestBuilder().build(), response)
175+
}
162176
}

0 commit comments

Comments
 (0)