Skip to content

Commit a404252

Browse files
authored
grpc-pb: Add gRPC Exceptions and skip unknown fields (#437)
* grpc-pb: Skip unknown fields Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Add InvalidProtobufError class Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Throw error in decoder instead of hadError() check Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Fix string encoding bug Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Remove hadError() method Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Throw exception instead of Boolean when encoding value Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Move JVM exception check to CODEC to avoid performance overhead Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Address PR comments Signed-off-by: Johannes Zottele <[email protected]> * grpc-pb: Address PR Comments Signed-off-by: Johannes Zottele <[email protected]> --------- Signed-off-by: Johannes Zottele <[email protected]>
1 parent 93c246c commit a404252

File tree

15 files changed

+344
-258
lines changed

15 files changed

+344
-258
lines changed

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.rpc.grpc.internal
66

7+
import kotlinx.rpc.grpc.pb.ProtobufDecodingException
78
import kotlinx.rpc.grpc.pb.WireDecoder
89

910
internal expect fun WireDecoder.pushLimit(byteLen: Int): Int
@@ -13,20 +14,15 @@ internal expect fun WireDecoder.bytesUntilLimit(): Int
1314
internal inline fun <T : Any> WireDecoder.readPackedVarInternal(
1415
crossinline size: () -> Long,
1516
crossinline readFn: () -> T,
16-
crossinline withError: () -> Unit,
17-
crossinline hadError: () -> Boolean,
1817
): List<T> {
1918
val byteLen = readInt32()
20-
if (hadError()) {
21-
return emptyList()
22-
}
2319
if (byteLen < 0) {
24-
return emptyList<T>().apply { withError() }
20+
throw ProtobufDecodingException.negativeSize()
2521
}
2622
val size = size()
2723
// no size check on jvm
2824
if (size != -1L && size < byteLen) {
29-
return emptyList<T>().apply { withError() }
25+
throw ProtobufDecodingException.truncatedMessage()
3026
}
3127
if (byteLen == 0) {
3228
return emptyList() // actually an empty list (no error)
@@ -38,9 +34,6 @@ internal inline fun <T : Any> WireDecoder.readPackedVarInternal(
3834

3935
while (bytesUntilLimit() > 0) {
4036
val elem = readFn()
41-
if (hadError()) {
42-
break
43-
}
4437
result.add(elem)
4538
}
4639

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ internal fun KTag.toRawKTag(): UInt {
3636
return (fieldNr.toUInt() shl KTag.Companion.K_TAG_TYPE_BITS) or wireType.ordinal.toUInt()
3737
}
3838

39-
internal fun KTag.Companion.fromOrNull(rawKTag: UInt): KTag? {
39+
internal fun KTag.Companion.from(rawKTag: UInt): KTag {
4040
val type = (rawKTag and K_TAG_TYPE_MASK).toInt()
4141
val field = (rawKTag shr K_TAG_TYPE_BITS).toInt()
4242
if (!isValidFieldNr(field)) {
43-
return null
43+
throw ProtobufDecodingException("Invalid field number: $field")
4444
}
4545
if (type >= WireType.entries.size) {
46-
return null
46+
throw ProtobufDecodingException("Invalid wire type: $type")
4747
}
4848
return KTag(field, WireType.entries[type])
4949
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.grpc.pb
6+
7+
public sealed class ProtobufException : RuntimeException {
8+
protected constructor(message: String, cause: Throwable? = null) : super(message, cause)
9+
}
10+
11+
12+
public class ProtobufDecodingException : ProtobufException {
13+
public constructor(message: String, cause: Throwable? = null) : super(message, cause)
14+
15+
public companion object Companion {
16+
internal fun missingRequiredField(messageName: String, fieldName: String) =
17+
ProtobufDecodingException("Message '$messageName' is missing a required field: $fieldName")
18+
19+
internal fun negativeSize() = ProtobufDecodingException(
20+
"Decoder encountered an embedded string or message which claimed to have negative size."
21+
)
22+
23+
internal fun invalidTag() = ProtobufDecodingException(
24+
"Protocol message contained an invalid tag (zero)."
25+
)
26+
27+
internal fun truncatedMessage() = ProtobufDecodingException(
28+
("While parsing a protocol message, the input ended unexpectedly "
29+
+ "in the middle of a field. This could mean either that the "
30+
+ "input has been truncated or that an embedded message "
31+
+ "misreported its own length.")
32+
)
33+
34+
internal fun genericParsingError() = ProtobufDecodingException("Failed to parse the message.")
35+
}
36+
}
37+
38+
public class ProtobufEncodingException : ProtobufException {
39+
public constructor(message: String, cause: Throwable? = null) : super(message, cause)
40+
}

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ internal const val MAX_PACKED_BULK_SIZE: Int = 1_000_000
1919
* This decoder is used by first calling [readTag], than looking up the field based on the field number in the returned,
2020
* tag and then calling the actual `read*()` method to read the value to the corresponding field.
2121
*
22-
* [hadError] indicates an error during decoding. While calling `read*()` is safe, the returned values
23-
* are meaningless if [hadError] returns `true`.
22+
* All `read*()` methods will throw an exception if the expected value couldn't be decoded.
23+
* Because of optimization reasons, the exception is platform-dependent. To unify them
24+
* wrap the decoding in a [checkForPlatformDecodeException] call, which turn platform-specific exceptions
25+
* into a [ProtobufDecodingException].
2426
*
25-
* NOTE: If the [hadError] after a call to `read*()` returns `false`, it doesn't mean that the
27+
* NOTE: If a call to `read*()` doesn't throw an error, it doesn't mean that the
2628
* value is correctly decoded. E.g., the following test will pass:
2729
* ```kt
2830
* val fieldNr = 1
@@ -32,17 +34,17 @@ internal const val MAX_PACKED_BULK_SIZE: Int = 1_000_000
3234
* assertTrue(encoder.writeInt32(fieldNr, 12312))
3335
* encoder.flush()
3436
*
35-
* WireDecoder(buffer).use { decoder ->
36-
* decoder.readTag()
37-
* decoder.readBool()
38-
* assertFalse(decoder.hasError())
37+
* checkForPlatformDecodeException {
38+
* WireDecoder(buffer).use { decoder ->
39+
* decoder.readTag()
40+
* decoder.readBool()
41+
* assertFalse(decoder.hasError())
42+
* }
3943
* }
4044
* ```
4145
*/
4246
@InternalRpcApi
4347
public interface WireDecoder : AutoCloseable {
44-
public fun hadError(): Boolean
45-
4648
/**
4749
* When the read tag is null, it indicates EOF and the parser may stop at this point.
4850
*/
@@ -79,17 +81,32 @@ public interface WireDecoder : AutoCloseable {
7981
public fun readPackedDouble(): List<Double>
8082
public fun readPackedEnum(): List<Int>
8183

82-
// TODO: Throw error instead of just returning
8384
public fun <T : InternalMessage> readMessage(msg: T, decoder: (T, WireDecoder) -> Unit) {
8485
val len = readInt32()
85-
if (hadError()) return
86-
if (len <= 0) return
86+
if (len < 0) throw ProtobufDecodingException.negativeSize()
8787
val limit = pushLimit(len)
8888
decoder(msg, this)
8989
popLimit(limit)
9090
}
91+
92+
public fun skipValue(writeType: WireType) {
93+
when (writeType) {
94+
WireType.VARINT -> readInt64()
95+
WireType.FIXED32 -> readFixed32()
96+
WireType.FIXED64 -> readFixed64()
97+
WireType.LENGTH_DELIMITED -> readBytes()
98+
WireType.START_GROUP -> throw ProtobufDecodingException("Unexpected START_GROUP wire type (KRPC-193)")
99+
WireType.END_GROUP -> {} // nothing to do
100+
}
101+
}
91102
}
92103

104+
/**
105+
* Turns exceptions thrown by different platforms during decoding into [ProtobufDecodingException].
106+
*/
107+
@InternalRpcApi
108+
public expect inline fun checkForPlatformDecodeException(block: () -> Unit)
109+
93110
/**
94111
* Creates a platform-specific [WireDecoder].
95112
*
@@ -100,4 +117,5 @@ public interface WireDecoder : AutoCloseable {
100117
*
101118
* @param source The buffer containing the encoded wire-format data.
102119
*/
103-
internal expect fun WireDecoder(source: Buffer): WireDecoder
120+
@InternalRpcApi
121+
public expect fun WireDecoder(source: Buffer): WireDecoder

grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt

Lines changed: 43 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,54 +10,63 @@ import kotlinx.rpc.internal.utils.InternalRpcApi
1010
/**
1111
* A platform-specific class that encodes values into protobuf's wire format.
1212
*
13-
* If one `write*()` method returns false, the encoding of the value failed
14-
* and no further encodings can be performed on this [WireEncoder].
13+
* If one `write*()` method fails to encode the value in the buffer,
14+
* it will throw a platform-specific exception.
15+
*
16+
* Wrap the encoding of a message with [checkForPlatformEncodeException] to
17+
* turn all thrown platform-specific exceptions into [ProtobufEncodingException]s.
1518
*
1619
* [flush] must be called to ensure that all data is written to the [Sink].
1720
*/
1821
@InternalRpcApi
1922
@OptIn(ExperimentalUnsignedTypes::class)
2023
public interface WireEncoder {
2124
public fun flush()
22-
public fun writeBool(fieldNr: Int, value: Boolean): Boolean
23-
public fun writeInt32(fieldNr: Int, value: Int): Boolean
24-
public fun writeInt64(fieldNr: Int, value: Long): Boolean
25-
public fun writeUInt32(fieldNr: Int, value: UInt): Boolean
26-
public fun writeUInt64(fieldNr: Int, value: ULong): Boolean
27-
public fun writeSInt32(fieldNr: Int, value: Int): Boolean
28-
public fun writeSInt64(fieldNr: Int, value: Long): Boolean
29-
public fun writeFixed32(fieldNr: Int, value: UInt): Boolean
30-
public fun writeFixed64(fieldNr: Int, value: ULong): Boolean
31-
public fun writeSFixed32(fieldNr: Int, value: Int): Boolean
32-
public fun writeSFixed64(fieldNr: Int, value: Long): Boolean
33-
public fun writeFloat(fieldNr: Int, value: Float): Boolean
34-
public fun writeDouble(fieldNr: Int, value: Double): Boolean
35-
public fun writeEnum(fieldNr: Int, value: Int): Boolean
36-
public fun writeBytes(fieldNr: Int, value: ByteArray): Boolean
37-
public fun writeString(fieldNr: Int, value: String): Boolean
38-
public fun writePackedBool(fieldNr: Int, value: List<Boolean>, fieldSize: Int): Boolean
39-
public fun writePackedInt32(fieldNr: Int, value: List<Int>, fieldSize: Int): Boolean
40-
public fun writePackedInt64(fieldNr: Int, value: List<Long>, fieldSize: Int): Boolean
41-
public fun writePackedUInt32(fieldNr: Int, value: List<UInt>, fieldSize: Int): Boolean
42-
public fun writePackedUInt64(fieldNr: Int, value: List<ULong>, fieldSize: Int): Boolean
43-
public fun writePackedSInt32(fieldNr: Int, value: List<Int>, fieldSize: Int): Boolean
44-
public fun writePackedSInt64(fieldNr: Int, value: List<Long>, fieldSize: Int): Boolean
45-
public fun writePackedFixed32(fieldNr: Int, value: List<UInt>): Boolean
46-
public fun writePackedFixed64(fieldNr: Int, value: List<ULong>): Boolean
47-
public fun writePackedSFixed32(fieldNr: Int, value: List<Int>): Boolean
48-
public fun writePackedSFixed64(fieldNr: Int, value: List<Long>): Boolean
49-
public fun writePackedFloat(fieldNr: Int, value: List<Float>): Boolean
50-
public fun writePackedDouble(fieldNr: Int, value: List<Double>): Boolean
51-
public fun writePackedEnum(fieldNr: Int, value: List<Int>, fieldSize: Int): Boolean =
25+
public fun writeBool(fieldNr: Int, value: Boolean)
26+
public fun writeInt32(fieldNr: Int, value: Int)
27+
public fun writeInt64(fieldNr: Int, value: Long)
28+
public fun writeUInt32(fieldNr: Int, value: UInt)
29+
public fun writeUInt64(fieldNr: Int, value: ULong)
30+
public fun writeSInt32(fieldNr: Int, value: Int)
31+
public fun writeSInt64(fieldNr: Int, value: Long)
32+
public fun writeFixed32(fieldNr: Int, value: UInt)
33+
public fun writeFixed64(fieldNr: Int, value: ULong)
34+
public fun writeSFixed32(fieldNr: Int, value: Int)
35+
public fun writeSFixed64(fieldNr: Int, value: Long)
36+
public fun writeFloat(fieldNr: Int, value: Float)
37+
public fun writeDouble(fieldNr: Int, value: Double)
38+
public fun writeEnum(fieldNr: Int, value: Int)
39+
public fun writeBytes(fieldNr: Int, value: ByteArray)
40+
public fun writeString(fieldNr: Int, value: String)
41+
public fun writePackedBool(fieldNr: Int, value: List<Boolean>, fieldSize: Int)
42+
public fun writePackedInt32(fieldNr: Int, value: List<Int>, fieldSize: Int)
43+
public fun writePackedInt64(fieldNr: Int, value: List<Long>, fieldSize: Int)
44+
public fun writePackedUInt32(fieldNr: Int, value: List<UInt>, fieldSize: Int)
45+
public fun writePackedUInt64(fieldNr: Int, value: List<ULong>, fieldSize: Int)
46+
public fun writePackedSInt32(fieldNr: Int, value: List<Int>, fieldSize: Int)
47+
public fun writePackedSInt64(fieldNr: Int, value: List<Long>, fieldSize: Int)
48+
public fun writePackedFixed32(fieldNr: Int, value: List<UInt>)
49+
public fun writePackedFixed64(fieldNr: Int, value: List<ULong>)
50+
public fun writePackedSFixed32(fieldNr: Int, value: List<Int>)
51+
public fun writePackedSFixed64(fieldNr: Int, value: List<Long>)
52+
public fun writePackedFloat(fieldNr: Int, value: List<Float>)
53+
public fun writePackedDouble(fieldNr: Int, value: List<Double>)
54+
public fun writePackedEnum(fieldNr: Int, value: List<Int>, fieldSize: Int): Unit =
5255
writePackedInt32(fieldNr, value, fieldSize)
5356

5457
public fun <T : InternalMessage> writeMessage(
5558
fieldNr: Int,
5659
value: T,
57-
encode: T.(WireEncoder) -> Unit
60+
encode: T.(WireEncoder) -> Unit,
5861
)
5962

6063
}
6164

65+
/**
66+
* Turns exceptions thrown by different platforms during encoding into [ProtobufEncodingException].
67+
*/
68+
@InternalRpcApi
69+
public expect inline fun checkForPlatformEncodeException(block: () -> Unit)
6270

63-
internal expect fun WireEncoder(sink: Sink): WireEncoder
71+
@InternalRpcApi
72+
public expect fun WireEncoder(sink: Sink): WireEncoder

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,24 @@ class ProtosTest {
3434
return codec.decode(source)
3535
}
3636

37+
@Test
38+
fun testUnknownFieldsDontCrash() {
39+
val buffer = Buffer()
40+
val encoder = WireEncoder(buffer)
41+
// optional sint32 sint32 = 7
42+
encoder.writeSInt32(7, 12)
43+
// optional sint64 sint64 = 8; (unknown as wrong wire-type)
44+
encoder.writeFloat(8, 2f)
45+
// optional fixed32 fixed32 = 9;
46+
encoder.writeFixed32(9, 1234u)
47+
encoder.flush()
48+
49+
val decoded = AllPrimitivesInternal.CODEC.decode(buffer)
50+
assertEquals(12, decoded.sint32)
51+
assertNull(decoded.sint64)
52+
assertEquals(1234u, decoded.fixed32)
53+
}
54+
3755
@Test
3856
fun testAllPrimitiveProto() {
3957
val msg = AllPrimitives {
@@ -86,7 +104,7 @@ class ProtosTest {
86104

87105
@Test
88106
fun testRepeatedWithRequiredSubField() {
89-
assertFailsWith<IllegalStateException> {
107+
assertFailsWith<ProtobufDecodingException> {
90108
RepeatedWithRequired {
91109
// we construct the message using the internal class,
92110
// so it is not invoking the checkRequired method on construction
@@ -98,7 +116,7 @@ class ProtosTest {
98116
@Test
99117
fun testPresenceCheckProto() {
100118
// Check a missing required field in a user-constructed message
101-
assertFailsWith<IllegalStateException> {
119+
assertFailsWith<ProtobufDecodingException> {
102120
PresenceCheck {}
103121
}
104122

@@ -108,7 +126,7 @@ class ProtosTest {
108126
encoder.writeFloat(2, 1f)
109127
encoder.flush()
110128

111-
assertFailsWith<IllegalStateException> {
129+
assertFailsWith<ProtobufDecodingException> {
112130
PresenceCheckInternal.CODEC.decode(buffer)
113131
}
114132
}
@@ -227,7 +245,7 @@ class ProtosTest {
227245

228246
@Test
229247
fun testOneOfRequiredSubField() {
230-
assertFailsWith<IllegalStateException> {
248+
assertFailsWith<ProtobufDecodingException> {
231249
OneOfWithRequired {
232250
// we construct the message using the internal class,
233251
// so it is not invoking the checkRequired method on construction
@@ -258,7 +276,7 @@ class ProtosTest {
258276

259277
@Test
260278
fun testRecursiveReqNotSet() {
261-
assertFailsWith<IllegalStateException> {
279+
assertFailsWith<ProtobufDecodingException> {
262280
val msg = RecursiveReq {
263281
rec = RecursiveReq {
264282
rec = RecursiveReq {
@@ -379,7 +397,7 @@ class ProtosTest {
379397
// we use the internal constructor to avoid a "missing required field" error during object construction
380398
val missingRequiredMessage = PresenceCheckInternal()
381399

382-
assertFailsWith<IllegalStateException> {
400+
assertFailsWith<ProtobufDecodingException> {
383401
val msg = TestMap {
384402
messages = mapOf(
385403
2 to missingRequiredMessage

0 commit comments

Comments
 (0)