Skip to content

Commit 2e22b7b

Browse files
authored
feat: use HashingSource and HashingSink in event streams (#788)
1 parent 4f44c67 commit 2e22b7b

File tree

4 files changed

+21
-50
lines changed

4 files changed

+21
-50
lines changed

aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/CrcUtil.kt

Lines changed: 0 additions & 34 deletions
This file was deleted.

aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Message.kt

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package aws.sdk.kotlin.runtime.protocol.eventstream
77

88
import aws.sdk.kotlin.runtime.InternalSdkApi
9+
import aws.smithy.kotlin.runtime.hashing.Crc32
910
import aws.smithy.kotlin.runtime.io.*
11+
import aws.smithy.kotlin.runtime.util.encodeToHex
1012

1113
internal const val MESSAGE_CRC_BYTE_LEN = 4
1214

@@ -50,14 +52,14 @@ public data class Message(val headers: List<Header>, val payload: ByteArray) {
5052
check(totalLen <= MAX_MESSAGE_SIZE.toUInt()) { "Invalid Message size: $totalLen" }
5153

5254
// Limiting the amount of data read by SdkBufferedSource is tricky and cause incorrect CRC
53-
// if not careful (e.g. creating a buffered source of CrcSource will usually lead to incorrect results
55+
// if not careful (e.g. creating a buffered source of a HashingSource will usually lead to incorrect results
5456
// because the entire point SdkBufferedSource (okio.BufferedSource) is to buffer larger chunks internally
5557
// to optimize short reads)
5658
val messageBuffer = SdkBuffer()
5759
val computedCrc = run {
58-
val crcSource = CrcSource(source)
60+
val crcSource = HashingSource(Crc32(), source)
5961
crcSource.read(messageBuffer, totalLen.toLong() - MESSAGE_CRC_BYTE_LEN.toLong())
60-
crcSource.crc
62+
crcSource.digest()
6163
}
6264

6365
val prelude = Prelude.decode(messageBuffer)
@@ -79,9 +81,9 @@ public data class Message(val headers: List<Header>, val payload: ByteArray) {
7981

8082
message.payload = messageBuffer.readByteArray(prelude.payloadLen.toLong())
8183

82-
val expectedCrc = source.readInt().toUInt()
83-
check(computedCrc == expectedCrc) {
84-
"Message checksum mismatch; expected=0x${expectedCrc.toString(16)}; calculated=0x${computedCrc.toString(16)}"
84+
val expectedCrc = source.readByteArray(4)
85+
check(computedCrc.contentEquals(expectedCrc)) {
86+
"Message checksum mismatch; expected=0x${expectedCrc.encodeToHex()}; calculated=0x${computedCrc.encodeToHex()}"
8587
}
8688
return message.build()
8789
}
@@ -119,15 +121,15 @@ public data class Message(val headers: List<Header>, val payload: ByteArray) {
119121

120122
val prelude = Prelude(messageLen.toInt(), headersLen.toInt())
121123

122-
val sink = CrcSink(dest)
124+
val sink = HashingSink(Crc32(), dest)
123125
val buffer = sink.buffer()
124126

125127
prelude.encode(buffer)
126128
buffer.write(headerBuf, headerBuf.size)
127129
buffer.write(payload)
128130

129131
buffer.emit()
130-
dest.writeInt(sink.crc.toInt())
132+
dest.write(sink.digest())
131133
}
132134
}
133135

aws-runtime/protocols/aws-event-stream/common/src/aws/sdk/kotlin/runtime/protocol/eventstream/Prelude.kt

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
package aws.sdk.kotlin.runtime.protocol.eventstream
77

88
import aws.sdk.kotlin.runtime.InternalSdkApi
9+
import aws.smithy.kotlin.runtime.hashing.Crc32
910
import aws.smithy.kotlin.runtime.io.*
11+
import aws.smithy.kotlin.runtime.util.encodeToHex
1012

1113
internal const val PRELUDE_BYTE_LEN = 8
1214
internal const val PRELUDE_BYTE_LEN_WITH_CRC = PRELUDE_BYTE_LEN + 4
@@ -28,13 +30,13 @@ public data class Prelude(val totalLen: Int, val headersLength: Int) {
2830
* Encode the prelude + CRC to [dest] buffer
2931
*/
3032
public fun encode(dest: SdkBufferedSink) {
31-
val sink = CrcSink(dest)
33+
val sink = HashingSink(Crc32(), dest)
3234
val buffer = sink.buffer()
3335

3436
buffer.writeInt(totalLen)
3537
buffer.writeInt(headersLength)
3638
buffer.emit()
37-
dest.writeInt(sink.crc.toInt())
39+
dest.write(sink.digest())
3840
}
3941

4042
public companion object {
@@ -43,19 +45,20 @@ public data class Prelude(val totalLen: Int, val headersLength: Int) {
4345
*/
4446
public fun decode(source: SdkBufferedSource): Prelude {
4547
check(source.request(PRELUDE_BYTE_LEN_WITH_CRC.toLong())) { "Invalid message prelude" }
46-
val crcSource = CrcSource(source)
48+
val crcSource = HashingSource(Crc32(), source)
4749
val buffer = SdkBuffer()
4850
crcSource.read(buffer, PRELUDE_BYTE_LEN.toLong())
49-
val expectedCrc = source.readInt().toUInt()
50-
val computedCrc = crcSource.crc
51+
52+
val expectedCrc = source.readByteArray(4)
53+
val computedCrc = crcSource.digest()
5154

5255
val totalLen = buffer.readInt()
5356
val headerLen = buffer.readInt()
5457

5558
check(totalLen <= MAX_MESSAGE_SIZE) { "Invalid Message size: $totalLen" }
5659
check(headerLen <= MAX_HEADER_SIZE) { "Invalid Header size: $headerLen" }
57-
check(expectedCrc == computedCrc) {
58-
"Prelude checksum mismatch; expected=0x${expectedCrc.toString(16)}; calculated=0x${computedCrc.toString(16)}"
60+
check(expectedCrc.contentEquals(computedCrc)) {
61+
"Prelude checksum mismatch; expected=0x${expectedCrc.encodeToHex()}; calculated=0x${computedCrc.encodeToHex()}"
5962
}
6063
return Prelude(totalLen, headerLen)
6164
}

aws-runtime/protocols/aws-event-stream/common/test/aws/sdk/kotlin/runtime/protocol/eventstream/MessageTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class MessageTest {
248248
val buffer = sdkBufferOf(encoded)
249249
assertFailsWith<IllegalStateException> {
250250
Message.decode(buffer)
251-
}.message.shouldContain("Message checksum mismatch; expected=0xdeadbeef; calculated=0x1a05860")
251+
}.message.shouldContain("Message checksum mismatch; expected=0xdeadbeef; calculated=0x01a05860")
252252
}
253253

254254
@Test

0 commit comments

Comments
 (0)