Skip to content

Commit 83a9703

Browse files
authored
feat: add event stream support (#692)
1 parent c451f0a commit 83a9703

File tree

11 files changed

+179
-22
lines changed

11 files changed

+179
-22
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "34fc4fc2-ad2d-4264-976c-5013324c10b7",
3+
"type": "feature",
4+
"description": "Add support for event streams",
5+
"issues": [
6+
"awslabs/aws-sdk-kotlin#543"
7+
]
8+
}

codegen/sdk/build.gradle.kts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@ data class AwsService(
9595

9696

9797
val disabledServices = setOf(
98-
// Only contains event streams
99-
"transcribe-streaming",
10098
// timestream requires endpoint discovery
10199
// https://github.com/awslabs/smithy-kotlin/issues/146
102100
"timestream-write",

codegen/smithy-aws-kotlin-codegen/src/main/kotlin/aws/sdk/kotlin/codegen/customization/RemoveEventStreamOperations.kt

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66
package aws.sdk.kotlin.codegen.customization
77

8+
import software.amazon.smithy.aws.traits.protocols.RestJson1Trait
9+
import software.amazon.smithy.aws.traits.protocols.RestXmlTrait
810
import software.amazon.smithy.kotlin.codegen.KotlinSettings
911
import software.amazon.smithy.kotlin.codegen.integration.KotlinIntegration
1012
import software.amazon.smithy.kotlin.codegen.model.expectShape
1113
import software.amazon.smithy.kotlin.codegen.model.findStreamingMember
1214
import software.amazon.smithy.kotlin.codegen.utils.getOrNull
1315
import software.amazon.smithy.model.Model
16+
import software.amazon.smithy.model.knowledge.ServiceIndex
1417
import software.amazon.smithy.model.shapes.OperationShape
15-
import software.amazon.smithy.model.shapes.ShapeId
1618
import software.amazon.smithy.model.shapes.StructureShape
1719
import software.amazon.smithy.model.transform.ModelTransformer
1820
import java.util.logging.Logger
@@ -25,13 +27,18 @@ class RemoveEventStreamOperations : KotlinIntegration {
2527
override val order: Byte = -127
2628
private val logger = Logger.getLogger(javaClass.name)
2729

28-
private val supportedServiceIds = setOf(
29-
// integration tests
30-
"aws.sdk.kotlin.test.eventstream#TestService",
31-
).map(ShapeId::from).toSet()
30+
// Protocol trait IDs checked by enabledForService: the integration is enabled
// when the target service declares a protocol that is not in this set.
private val supportedProtocols = setOf(
31+
RestXmlTrait.ID,
32+
RestJson1Trait.ID,
33+
)
34+
/**
 * Enables this integration when at least one protocol declared by the service named in
 * [settings] is missing from [supportedProtocols].
 */
override fun enabledForService(model: Model, settings: KotlinSettings): Boolean =
    ServiceIndex(model)
        .getProtocols(settings.service)
        .values
        .any { protocolTrait -> protocolTrait.toShapeId() !in supportedProtocols }
3542

3643
override fun preprocessModel(model: Model, settings: KotlinSettings): Model =
3744
ModelTransformer.create().filterShapes(model) { parentShape ->

services/build.gradle.kts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,14 @@ subprojects {
7777

7878
if (project.file("e2eTest").exists()) {
7979
jvm().compilations {
80-
val main by getting
8180
val e2eTest by creating {
8281
defaultSourceSet {
83-
kotlin.srcDir("e2eTest")
82+
kotlin.srcDir("e2eTest/src")
83+
resources.srcDir("e2eTest/test-resources")
84+
dependsOn(sourceSets.getByName("commonMain"))
85+
dependsOn(sourceSets.getByName("jvmMain"))
8486

8587
dependencies {
86-
// Compile against the main compilation's compile classpath and outputs:
87-
implementation(main.compileDependencyFiles + main.runtimeDependencyFiles + main.output.classesDirs)
88-
8988
implementation(kotlin("test"))
9089
implementation(kotlin("test-junit5"))
9190
implementation(project(":aws-runtime:testing"))

services/s3/e2eTest/S3IntegrationTest.kt renamed to services/s3/e2eTest/src/S3IntegrationTest.kt

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,8 @@
44
*/
55
package aws.sdk.kotlin.e2etest
66

7-
import aws.sdk.kotlin.services.s3.S3Client
8-
import aws.sdk.kotlin.services.s3.completeMultipartUpload
9-
import aws.sdk.kotlin.services.s3.createMultipartUpload
10-
import aws.sdk.kotlin.services.s3.listObjects
11-
import aws.sdk.kotlin.services.s3.model.CompletedPart
12-
import aws.sdk.kotlin.services.s3.model.GetObjectRequest
13-
import aws.sdk.kotlin.services.s3.putObject
14-
import aws.sdk.kotlin.services.s3.uploadPart
7+
import aws.sdk.kotlin.services.s3.*
8+
import aws.sdk.kotlin.services.s3.model.*
159
import aws.sdk.kotlin.testing.PRINTABLE_CHARS
1610
import aws.sdk.kotlin.testing.withAllEngines
1711
import aws.smithy.kotlin.runtime.content.ByteStream
@@ -23,13 +17,15 @@ import aws.smithy.kotlin.runtime.hashing.sha256
2317
import aws.smithy.kotlin.runtime.testing.RandomTempFile
2418
import aws.smithy.kotlin.runtime.util.encodeToHex
2519
import kotlinx.coroutines.*
20+
import kotlinx.coroutines.flow.toList
2621
import org.junit.jupiter.api.AfterAll
2722
import org.junit.jupiter.api.BeforeAll
2823
import org.junit.jupiter.api.TestInstance
2924
import java.io.File
3025
import java.util.UUID
3126
import kotlin.test.Test
3227
import kotlin.test.assertEquals
28+
import kotlin.test.assertIs
3329
import kotlin.time.Duration.Companion.seconds
3430
import kotlin.time.ExperimentalTime
3531

@@ -217,6 +213,62 @@ class S3BucketOpsIntegrationTest {
217213
assertEquals(expectedSha256, actualSha256)
218214
}
219215
}
216+
217+
/**
 * Uploads a small CSV object, runs an S3 Select query against it and verifies that the
 * response event stream contains exactly a Records, a Stats and an End event, in that
 * order, with the Records payload holding the single matching row.
 */
@Test
fun testSelectObjectEventStream(): Unit = runBlocking {
    S3Client.fromEnvironment().use { s3 ->
        // upload our content to select from
        val objKey = "developers.csv"

        val content = """
            Name,PhoneNumber,City,Occupation
            Sam,(949) 555-6701,Irvine,Solutions Architect
            Vinod,(949) 555-6702,Los Angeles,Solutions Architect
            Jeff,(949) 555-6703,Seattle,AWS Evangelist
            Jane,(949) 555-6704,Chicago,Developer
            Sean,(949) 555-6705,Indianapolis,Developer
            Mary,(949) 555-6706,Detroit,Developer
            Kate,(949) 555-6707,Boston,Solutions Architect
        """.trimIndent()

        s3.putObject {
            bucket = testBucket
            key = objKey
            body = ByteStream.fromString(content)
        }

        // select content as an event stream
        val req = SelectObjectContentRequest {
            bucket = testBucket
            key = objKey
            expressionType = ExpressionType.Sql
            expression = """SELECT * FROM s3object s where s."Name" = 'Jane'"""
            inputSerialization {
                csv {
                    fileHeaderInfo = FileHeaderInfo.Use
                }
                compressionType = CompressionType.None
            }
            outputSerialization {
                csv { }
            }
        }

        val events = s3.selectObjectContent(req) { resp ->
            // fail with a descriptive message (instead of a bare NPE) if the stream is absent,
            // then collect the flow to a list while the response is still open
            checkNotNull(resp.payload) { "SelectObjectContent response has no event stream payload" }.toList()
        }

        assertEquals(3, events.size)

        val records = assertIs<SelectObjectContentEventStream.Records>(events[0])
        assertIs<SelectObjectContentEventStream.Stats>(events[1])
        assertIs<SelectObjectContentEventStream.End>(events[2])

        val expectedRecord = "Jane,(949) 555-6704,Chicago,Developer\n"
        assertEquals(expectedRecord, records.value.payload?.decodeToString())
    }
}
220272
}
221273

222274
// generate sequence of "chunks" where each range defines the inclusive start and end bytes
File renamed without changes.
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package aws.sdk.kotlin.e2etest
6+
7+
import aws.sdk.kotlin.services.transcribestreaming.TranscribeStreamingClient
8+
import aws.sdk.kotlin.services.transcribestreaming.model.*
9+
import kotlinx.coroutines.Dispatchers
10+
import kotlinx.coroutines.flow.Flow
11+
import kotlinx.coroutines.flow.flow
12+
import kotlinx.coroutines.flow.flowOn
13+
import kotlinx.coroutines.runBlocking
14+
import org.junit.jupiter.api.Test
15+
import org.junit.jupiter.api.TestInstance
16+
import java.io.File
17+
import java.nio.file.Paths
18+
import javax.sound.sampled.AudioSystem
19+
import kotlin.test.assertTrue
20+
21+
/**
 * End-to-end test that streams a bundled WAV file to Transcribe Streaming and checks the
 * transcript that comes back over the response event stream.
 */
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class TranscribeStreamingIntegrationTest {

    @Test
    fun testTranscribeEventStream(): Unit = runBlocking {
        // Inside `runBlocking` the implicit `this` is the CoroutineScope, not the test instance,
        // so name the test class explicitly to look the resource up through its class loader.
        val url = TranscribeStreamingIntegrationTest::class.java.classLoader.getResource("hello-kotlin-8000.wav")
            ?: error("failed to load test resource")
        val audioFile = Paths.get(url.toURI()).toFile()

        TranscribeStreamingClient { region = "us-east-2" }.use { client ->
            val transcript = getTranscript(client, audioFile)
            assertTrue(transcript.startsWith("Hello from", true), "full transcript: $transcript")
        }
    }
}
35+
36+
// Number of audio frames requested per read; multiplied by the stream's frame size to
// size the byte buffer behind each emitted audio event (the last chunk may be shorter).
private const val FRAMES_PER_CHUNK = 4096
37+
38+
/**
 * Creates a cold flow that reads [file] as audio and emits its contents as
 * [AudioStream.AudioEvent]s of at most [FRAMES_PER_CHUNK] frames each.
 *
 * The file's audio format is probed eagerly, so an unreadable or unsupported file still
 * fails at call time. The audio input stream itself is opened once per collection and is
 * closed when the collection completes, fails or is cancelled. Reading runs on
 * [Dispatchers.IO].
 */
private fun audioStreamFromFile(file: File): Flow<AudioStream> {
    val format = AudioSystem.getAudioFileFormat(file)

    return flow {
        // `use` guarantees the underlying file handle is released however the flow terminates
        AudioSystem.getAudioInputStream(file).use { ais ->
            val bytesPerFrame = ais.format.frameSize
            println("audio stream format of $file: $format; bytesPerFrame=$bytesPerFrame")

            while (true) {
                // a fresh buffer per iteration: a full buffer is handed to the emitted event as-is
                val frameBuffer = ByteArray(FRAMES_PER_CHUNK * bytesPerFrame)
                val rc = ais.read(frameBuffer)
                if (rc <= 0) {
                    break
                }

                // trim the final, partially filled buffer so no trailing zero bytes are sent
                val chunk = if (rc < frameBuffer.size) frameBuffer.copyOf(rc) else frameBuffer
                val event = AudioStream.AudioEvent(
                    AudioEvent {
                        audioChunk = chunk
                    },
                )

                println("emitting event")
                emit(event)
            }
        }
    }.flowOn(Dispatchers.IO)
}
64+
65+
/**
 * Starts a streaming transcription of [audioFile] with [client] and returns the
 * concatenation of the first alternative of every non-partial result received.
 *
 * Every received result is logged to stdout. Any event other than a
 * [TranscriptResultStream.TranscriptEvent] aborts with an [IllegalStateException].
 */
private suspend fun getTranscript(client: TranscribeStreamingClient, audioFile: File): String {
    val request = StartStreamTranscriptionRequest {
        languageCode = LanguageCode.EnUs
        mediaSampleRateHertz = 8000
        mediaEncoding = MediaEncoding.Pcm
        audioStream = audioStreamFromFile(audioFile)
    }

    return client.startStreamTranscription(request) { resp ->
        val finalText = StringBuilder()

        resp.transcriptResultStream?.collect { event ->
            // guard clause: only transcript events are expected on this stream
            if (event !is TranscriptResultStream.TranscriptEvent) error("unknown event $event")

            for (result in event.value.transcript?.results.orEmpty()) {
                val transcript = result.alternatives?.firstOrNull()?.transcript
                println("received TranscriptEvent: isPartial=${result.isPartial}; transcript=$transcript")

                // partial results are superseded later, so only final ones are accumulated
                if (result.isPartial) continue
                if (transcript != null) finalText.append(transcript)
            }
        }

        finalText.toString()
    }
}

0 commit comments

Comments
 (0)