Skip to content

Commit d1b5cd0

Browse files
committed
grpc-pb: Throw error in decoder instead of hadError() check
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 87b41ad commit d1b5cd0

File tree

12 files changed

+127
-88
lines changed

12 files changed

+127
-88
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.grpc
6+
7+
public sealed class GrpcException : RuntimeException {
8+
protected constructor(message: String, cause: Throwable? = null) : super(message, cause)
9+
}
10+
11+
12+
public class ProtobufDecodingException : GrpcException {
13+
internal constructor(message: String, cause: Throwable? = null) : super(message, cause)
14+
15+
public companion object Companion {
16+
internal fun missingRequiredField(messageName: String, fieldName: String) =
17+
ProtobufDecodingException("Message '$messageName' is missing a required field: $fieldName")
18+
19+
internal fun negativeSize() = ProtobufDecodingException(
20+
"CodedInputStream encountered an embedded string or message which claimed to have negative size."
21+
)
22+
23+
internal fun invalidTag() = ProtobufDecodingException(
24+
"Protocol message contained an invalid tag (zero)."
25+
)
26+
27+
internal fun truncatedMessage() = ProtobufDecodingException(
28+
("While parsing a protocol message, the input ended unexpectedly "
29+
+ "in the middle of a field. This could mean either that the "
30+
+ "input has been truncated or that an embedded message "
31+
+ "misreported its own length.")
32+
)
33+
34+
internal fun genericParsingError() = ProtobufDecodingException("Failed to parse the message.")
35+
}
36+
}
37+
38+
public class ProtobufEncodingException : GrpcException {
39+
internal constructor(message: String, cause: Throwable? = null) : super(message, cause)
40+
}

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.rpc.grpc.internal
66

7+
import kotlinx.rpc.grpc.ProtobufDecodingException
78
import kotlinx.rpc.grpc.pb.WireDecoder
89

910
internal expect fun WireDecoder.pushLimit(byteLen: Int): Int
@@ -13,20 +14,15 @@ internal expect fun WireDecoder.bytesUntilLimit(): Int
1314
internal inline fun <T : Any> WireDecoder.readPackedVarInternal(
1415
crossinline size: () -> Long,
1516
crossinline readFn: () -> T,
16-
crossinline withError: () -> Unit,
17-
crossinline hadError: () -> Boolean,
1817
): List<T> {
1918
val byteLen = readInt32()
20-
if (hadError()) {
21-
return emptyList()
22-
}
2319
if (byteLen < 0) {
24-
return emptyList<T>().apply { withError() }
20+
throw ProtobufDecodingException.negativeSize()
2521
}
2622
val size = size()
2723
// no size check on jvm
2824
if (size != -1L && size < byteLen) {
29-
return emptyList<T>().apply { withError() }
25+
throw ProtobufDecodingException.truncatedMessage()
3026
}
3127
if (byteLen == 0) {
3228
return emptyList() // actually an empty list (no error)
@@ -38,9 +34,6 @@ internal inline fun <T : Any> WireDecoder.readPackedVarInternal(
3834

3935
while (bytesUntilLimit() > 0) {
4036
val elem = readFn()
41-
if (hadError()) {
42-
break
43-
}
4437
result.add(elem)
4538
}
4639

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ internal fun KTag.toRawKTag(): UInt {
3636
return (fieldNr.toUInt() shl KTag.Companion.K_TAG_TYPE_BITS) or wireType.ordinal.toUInt()
3737
}
3838

39-
internal fun KTag.Companion.fromOrNull(rawKTag: UInt): KTag? {
39+
internal fun KTag.Companion.from(rawKTag: UInt): KTag {
4040
val type = (rawKTag and K_TAG_TYPE_MASK).toInt()
4141
val field = (rawKTag shr K_TAG_TYPE_BITS).toInt()
4242
if (!isValidFieldNr(field)) {
43-
return null
43+
error("Invalid field number: $field")
4444
}
4545
if (type >= WireType.entries.size) {
46-
return null
46+
error("Invalid wire type: $type")
4747
}
4848
return KTag(field, WireType.entries[type])
4949
}

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ public interface WireDecoder : AutoCloseable {
9595
WireType.FIXED32 -> readFixed32()
9696
WireType.FIXED64 -> readFixed64()
9797
WireType.LENGTH_DELIMITED -> readBytes()
98-
WireType.START_GROUP -> error("Unexpected START_GROUP wire type")
98+
WireType.START_GROUP -> error("Unexpected START_GROUP wire type (KRPC-193)")
9999
WireType.END_GROUP -> {} // nothing to do
100100
}
101101
}

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import asInternal
1313
import encodeWith
1414
import invoke
1515
import kotlinx.io.Buffer
16-
import kotlinx.rpc.grpc.InvalidProtobufError
16+
import kotlinx.rpc.grpc.ProtobufDecodingException
1717
import kotlinx.rpc.grpc.codec.MessageCodec
1818
import kotlinx.rpc.grpc.test.*
1919
import kotlinx.rpc.grpc.test.common.*
@@ -105,7 +105,7 @@ class ProtosTest {
105105

106106
@Test
107107
fun testRepeatedWithRequiredSubField() {
108-
assertFailsWith<InvalidProtobufError> {
108+
assertFailsWith<ProtobufDecodingException> {
109109
RepeatedWithRequired {
110110
// we construct the message using the internal class,
111111
// so it is not invoking the checkRequired method on construction
@@ -117,7 +117,7 @@ class ProtosTest {
117117
@Test
118118
fun testPresenceCheckProto() {
119119
// Check a missing required field in a user-constructed message
120-
assertFailsWith<InvalidProtobufError> {
120+
assertFailsWith<ProtobufDecodingException> {
121121
PresenceCheck {}
122122
}
123123

@@ -127,7 +127,7 @@ class ProtosTest {
127127
encoder.writeFloat(2, 1f)
128128
encoder.flush()
129129

130-
assertFailsWith<InvalidProtobufError> {
130+
assertFailsWith<ProtobufDecodingException> {
131131
PresenceCheckInternal.CODEC.decode(buffer)
132132
}
133133
}
@@ -246,7 +246,7 @@ class ProtosTest {
246246

247247
@Test
248248
fun testOneOfRequiredSubField() {
249-
assertFailsWith<InvalidProtobufError> {
249+
assertFailsWith<ProtobufDecodingException> {
250250
OneOfWithRequired {
251251
// we construct the message using the internal class,
252252
// so it is not invoking the checkRequired method on construction
@@ -277,7 +277,7 @@ class ProtosTest {
277277

278278
@Test
279279
fun testRecursiveReqNotSet() {
280-
assertFailsWith<InvalidProtobufError> {
280+
assertFailsWith<ProtobufDecodingException> {
281281
val msg = RecursiveReq {
282282
rec = RecursiveReq {
283283
rec = RecursiveReq {
@@ -398,7 +398,7 @@ class ProtosTest {
398398
// we use the internal constructor to avoid a "missing required field" error during object construction
399399
val missingRequiredMessage = PresenceCheckInternal()
400400

401-
assertFailsWith<InvalidProtobufError> {
401+
assertFailsWith<ProtobufDecodingException> {
402402
val msg = TestMap {
403403
messages = mapOf(
404404
2 to missingRequiredMessage

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package kotlinx.rpc.grpc.pb
66

77
import kotlinx.io.Buffer
8+
import kotlinx.rpc.grpc.ProtobufDecodingException
89
import kotlin.test.*
910

1011
enum class TestPlatform {
@@ -808,4 +809,15 @@ class WireCodecTest {
808809
WireDecoder::readPackedEnum
809810
)
810811

812+
813+
@Test
814+
fun testInvalidTag() {
815+
val buffer = Buffer()
816+
buffer.writeByte(0)
817+
818+
assertFailsWith<ProtobufDecodingException> {
819+
WireDecoder(buffer).readTag()
820+
}
821+
}
822+
811823
}

grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
package kotlinx.rpc.grpc.pb
66

77
import com.google.protobuf.CodedInputStream
8+
import com.google.protobuf.InvalidProtocolBufferException
89
import kotlinx.io.Buffer
910
import kotlinx.io.asInputStream
11+
import kotlinx.rpc.grpc.ProtobufDecodingException
1012
import kotlinx.rpc.grpc.internal.readPackedVarInternal
1113

1214
internal class WireDecoderJvm(source: Buffer) : WireDecoder {
@@ -18,80 +20,79 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder {
1820
return false
1921
}
2022

21-
override fun readTag(): KTag? {
23+
override fun readTag(): KTag? = checked {
2224
val tag = codedInputStream.readTag().toUInt()
2325
if (tag == 0u) {
2426
return null
2527
}
26-
27-
return KTag.fromOrNull(tag)
28+
return KTag.from(tag)
2829
}
2930

30-
override fun readBool(): Boolean {
31+
override fun readBool(): Boolean = checked {
3132
return codedInputStream.readBool()
3233
}
3334

34-
override fun readInt32(): Int {
35+
override fun readInt32(): Int = checked {
3536
return codedInputStream.readInt32()
3637
}
3738

38-
override fun readInt64(): Long {
39+
override fun readInt64(): Long = checked {
3940
return codedInputStream.readInt64()
4041
}
4142

42-
override fun readUInt32(): UInt {
43+
override fun readUInt32(): UInt = checked {
4344
// todo check java unsigned types
4445
return codedInputStream.readUInt32().toUInt()
4546
}
4647

47-
override fun readUInt64(): ULong {
48+
override fun readUInt64(): ULong = checked {
4849
// todo check java unsigned types
4950
return codedInputStream.readUInt64().toULong()
5051
}
5152

52-
override fun readSInt32(): Int {
53+
override fun readSInt32(): Int = checked {
5354
return codedInputStream.readSInt32()
5455
}
5556

56-
override fun readSInt64(): Long {
57+
override fun readSInt64(): Long = checked {
5758
return codedInputStream.readSInt64()
5859
}
5960

60-
override fun readFixed32(): UInt {
61+
override fun readFixed32(): UInt = checked {
6162
// todo check java unsigned types
6263
return codedInputStream.readFixed32().toUInt()
6364
}
6465

65-
override fun readFixed64(): ULong {
66+
override fun readFixed64(): ULong = checked {
6667
// todo check java unsigned types
6768
return codedInputStream.readFixed64().toULong()
6869
}
6970

70-
override fun readSFixed32(): Int {
71+
override fun readSFixed32(): Int = checked {
7172
return codedInputStream.readSFixed32()
7273
}
7374

74-
override fun readSFixed64(): Long {
75+
override fun readSFixed64(): Long = checked {
7576
return codedInputStream.readSFixed64()
7677
}
7778

78-
override fun readFloat(): Float {
79+
override fun readFloat(): Float = checked {
7980
return codedInputStream.readFloat()
8081
}
8182

82-
override fun readDouble(): Double {
83+
override fun readDouble(): Double = checked {
8384
return codedInputStream.readDouble()
8485
}
8586

86-
override fun readEnum(): Int {
87+
override fun readEnum(): Int = checked {
8788
return codedInputStream.readEnum()
8889
}
8990

90-
override fun readString(): String {
91+
override fun readString(): String = checked {
9192
return codedInputStream.readStringRequireUtf8()
9293
}
9394

94-
override fun readBytes(): ByteArray {
95+
override fun readBytes(): ByteArray = checked {
9596
return codedInputStream.readByteArray()
9697
}
9798

@@ -114,12 +115,25 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder {
114115

115116
private fun <T : Any> readPackedInternal(read: () -> T) = readPackedVarInternal(
116117
size = { -1 },
117-
readFn = read,
118-
withError = { },
119-
hadError = { false },
118+
readFn = read
120119
)
121120
}
122121

123122
internal actual fun WireDecoder(source: Buffer): WireDecoder {
124123
return WireDecoderJvm(source)
125124
}
125+
126+
/**
127+
* Turns a [InvalidProtocolBufferException] into our own [ProtobufDecodingException].
128+
*/
129+
private inline fun <reified T> checked(block: () -> T): T {
130+
try {
131+
return block()
132+
} catch (e: InvalidProtocolBufferException) {
133+
throw e.toDecodingException()
134+
}
135+
}
136+
137+
private fun InvalidProtocolBufferException.toDecodingException(): ProtobufDecodingException {
138+
return ProtobufDecodingException(message ?: "Failed to decode protobuf message.", cause)
139+
}

0 commit comments

Comments
 (0)