Skip to content

Commit fd1fd9f

Browse files
committed
grpc-pb: Enum nullpointer exception state
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 63ae758 commit fd1fd9f

File tree

8 files changed

+152
-15
lines changed

8 files changed

+152
-15
lines changed

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,29 @@ package kotlinx.rpc.grpc.pb
66

77
import kotlinx.io.Buffer
88
import kotlinx.rpc.grpc.internal.MessageCodec
9+
import kotlinx.rpc.grpc.test.Enum
10+
import kotlinx.rpc.grpc.test.UsingEnum
11+
import kotlinx.rpc.grpc.test.UsingEnumInternal
912
import kotlinx.rpc.grpc.test.common.*
13+
import kotlinx.rpc.grpc.test.invoke
1014
import kotlin.test.Test
1115
import kotlin.test.assertEquals
1216
import kotlin.test.assertFailsWith
1317

1418
class ProtosTest {
1519

20+
@Test
21+
fun bugDemo() {
22+
val buffer = Buffer()
23+
val encoder = WireEncoder(buffer)
24+
encoder.writeEnum(1, 0)
25+
encoder.flush()
26+
27+
val decodedMsg = UsingEnumInternal.CODEC.decode(buffer)
28+
assertEquals(Enum.UNRECOGNIZED(50), decodedMsg.enum)
29+
}
30+
31+
1632
private fun <M> decodeEncode(
1733
msg: M,
1834
codec: MessageCodec<M>
@@ -84,5 +100,40 @@ class ProtosTest {
84100
}
85101
}
86102

103+
@Test
104+
fun testEnumUnrecognized() {
105+
// write unknown enum value
106+
val buffer = Buffer()
107+
val encoder = WireEncoder(buffer)
108+
encoder.writeEnum(1, 50)
109+
encoder.flush()
110+
111+
val decodedMsg = UsingEnumInternal.CODEC.decode(buffer)
112+
assertEquals(Enum.UNRECOGNIZED(50), decodedMsg.enum)
113+
}
114+
115+
@Test
116+
fun testEnumAlias() {
117+
val msg = UsingEnum {
118+
enum = Enum.ONE_SECOND
119+
}
120+
121+
val decodedMsg = decodeEncode(msg, UsingEnumInternal.CODEC)
122+
assertEquals(Enum.ONE, decodedMsg.enum)
123+
assertEquals(Enum.ONE_SECOND, decodedMsg.enum)
124+
}
125+
126+
@Test
127+
fun testDefault() {
128+
// create message without enum field set
129+
val msg = UsingEnum {}
130+
131+
val buffer = UsingEnumInternal.CODEC.encode(msg) as Buffer
132+
// buffer should be empty (default is not in wire)
133+
assertEquals(0, buffer.size)
134+
135+
val decoded = UsingEnumInternal.CODEC.decode(buffer)
136+
assertEquals(Enum.ZERO, decoded.enum)
137+
}
87138

88139
}

grpc/grpc-core/src/commonTest/proto/exclude/enum.proto renamed to grpc/grpc-core/src/commonTest/proto/enum.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ enum Enum {
1414
message UsingEnum {
1515
Enum enum = 1;
1616
}
17+
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
message OneOfMsg {
2+
3+
oneof f {
4+
int32 i = 2;
5+
fixed64 T = 3;
6+
}
7+
8+
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,12 @@ open class CodeGenerator(
102102

103103
internal fun whenBlock(
104104
condition: String? = null,
105+
prefix: String = "",
105106
block: (CodeGenerator.() -> Unit),
106107
) {
108+
val pre = if (prefix.isNotEmpty()) prefix.trim() + " " else ""
107109
val cond = condition?.let { " ($it)" } ?: ""
108-
scope("when$cond", block = block)
110+
scope("${pre}when$cond", block = block)
109111
}
110112

111113
internal fun whenCase(

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt

Lines changed: 76 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,30 @@ class ModelToKotlinCommonGenerator(
9494
}
9595

9696
private fun CodeGenerator.generateInternalDeclaredEntities(fileDeclaration: FileDeclaration) {
97-
fileDeclaration.messageDeclarations.forEach { generateInternalMessage(it) }
97+
generateInternalMessageEntities(fileDeclaration.messageDeclarations)
9898

99-
fileDeclaration.messageDeclarations.forEach {
99+
val allEnums =
100+
fileDeclaration.enumDeclarations + fileDeclaration.messageDeclarations.flatMap { it.allEnumsRecursively() }
101+
allEnums.forEach { enum ->
102+
generateInternalEnumConstructor(enum)
103+
}
104+
}
105+
106+
private fun CodeGenerator.generateInternalMessageEntities(messages: List<MessageDeclaration>) {
107+
messages.forEach { generateInternalMessage(it) }
108+
109+
messages.forEach {
100110
generateMessageConstructor(it)
101111
}
102112

103-
fileDeclaration.messageDeclarations.forEach {
113+
messages.forEach {
104114
generateRequiredCheck(it)
105115
generateMessageEncoder(it)
106116
generateMessageDecoder(it)
107117
}
108-
109118
}
110119

120+
111121
private fun MessageDeclaration.fields() = actualFields.map {
112122
it.transformToFieldDeclaration() to it
113123
}
@@ -287,6 +297,11 @@ class ModelToKotlinCommonGenerator(
287297
}
288298
}
289299

300+
is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") {
301+
val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber"
302+
code("$assignment $fromNum(decoder.read$encFuncName())")
303+
}
304+
290305
is FieldType.Map -> TODO()
291306
is FieldType.OneOf -> TODO()
292307
is FieldType.Reference -> TODO()
@@ -338,13 +353,32 @@ class ModelToKotlinCommonGenerator(
338353
"$variable.forEach { encoder.write${fieldType.value.decodeEncodeFuncName()}($number, it) }"
339354
}
340355

356+
is FieldType.Enum -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable.number)"
357+
341358
is FieldType.Map -> TODO()
342359
is FieldType.OneOf -> TODO()
343-
is FieldType.Reference -> TODO()
360+
is FieldType.Reference -> "<TODO: Implement Reference writeValue()>"
344361
}
345362
}
346363

347364

365+
private fun CodeGenerator.generateInternalEnumConstructor(enum: EnumDeclaration) {
366+
function(
367+
"fromNumber",
368+
modifiers = "private",
369+
args = "number: Int",
370+
contextReceiver = "${enum.name.safeFullName()}.Companion",
371+
returnType = enum.name.safeFullName(),
372+
) {
373+
scope("for (entry in entries)") {
374+
ifBranch(condition = "entry.number == number", ifBlock = {
375+
code("return entry")
376+
})
377+
}
378+
code("return ${enum.name.safeFullName()}.UNRECOGNIZED(number)")
379+
}
380+
}
381+
348382
/**
349383
* Generates a function to check for the presence of all required fields in a message declaration.
350384
*/
@@ -383,6 +417,7 @@ class ModelToKotlinCommonGenerator(
383417
else -> error("Unexpected use of size call for field: $name, type: $fieldType")
384418
}
385419

420+
is FieldType.Enum -> sizeFunc
386421
is FieldType.Map -> TODO()
387422
is FieldType.OneOf -> TODO()
388423
is FieldType.Reference -> TODO()
@@ -397,6 +432,9 @@ class ModelToKotlinCommonGenerator(
397432
}
398433

399434
is FieldType.List -> "$name.isNotEmpty()"
435+
is FieldType.Reference -> "<TODO: Implement Reference defaultCheck>"
436+
437+
is FieldType.Enum -> "${fieldType.defaultValue} != $name"
400438

401439
else -> TODO("Field: $name, type: $fieldType")
402440
}
@@ -419,6 +457,7 @@ class ModelToKotlinCommonGenerator(
419457
FieldType.IntegralType.SFIXED32 -> "SFixed32"
420458
FieldType.IntegralType.SFIXED64 -> "SFixed64"
421459
is FieldType.List -> "Packed${value.decodeEncodeFuncName()}"
460+
is FieldType.Enum -> "Enum"
422461
is FieldType.Map -> error("No encoding/decoding function for map types")
423462
is FieldType.OneOf -> error("No encoding/decoding function for oneOf types")
424463
is FieldType.Reference -> error("No encoding/decoding function for sub message types")
@@ -435,6 +474,8 @@ class ModelToKotlinCommonGenerator(
435474
value.safeFullName()
436475
}
437476

477+
is FieldType.Enum -> type.dec.name.safeFullName()
478+
438479
is FieldType.OneOf -> {
439480
val value by type.value
440481
value.safeFullName()
@@ -471,6 +512,7 @@ class ModelToKotlinCommonGenerator(
471512

472513
"Map<${fqKey.safeFullName()}, ${fqValue.safeFullName()}>"
473514
}
515+
474516
}.withNullability(nullable)
475517
}
476518

@@ -497,24 +539,44 @@ class ModelToKotlinCommonGenerator(
497539
}
498540

499541
private fun CodeGenerator.generatePublicEnum(declaration: EnumDeclaration) {
500-
clazz(declaration.name.simpleName, modifiers = "enum") {
501-
declaration.originalEntries.forEach { entry ->
502-
code("${entry.name.simpleName},")
503-
newLine()
542+
543+
val className = declaration.name.simpleName
544+
545+
val entriesSorted = declaration.originalEntries.sortedBy { it.dec.number }
546+
547+
clazz(
548+
className, "sealed",
549+
constructorArgs = listOf("val number: Int"),
550+
) {
551+
552+
declaration.originalEntries.forEach { variant ->
553+
clazz(
554+
name = variant.name.simpleName,
555+
declarationType = DeclarationType.Object,
556+
superTypes = listOf("$className(number = ${variant.dec.number})"),
557+
)
504558
}
505-
code(";")
506-
newLine()
559+
560+
// TODO: Avoid name conflict
561+
clazz(
562+
name = "UNRECOGNIZED",
563+
constructorArgs = listOf("number: Int"),
564+
superTypes = listOf("$className(number)"),
565+
)
507566

508567
if (declaration.aliases.isNotEmpty()) {
509568
newLine()
510569

511570
clazz("", modifiers = "companion", declarationType = DeclarationType.Object) {
512571
declaration.aliases.forEach { alias: EnumDeclaration.Alias ->
513572
code(
514-
"val ${alias.name.simpleName}: ${declaration.name.simpleName} " +
573+
"val ${alias.name.simpleName}: $className " +
515574
"= ${alias.original.name.simpleName}"
516575
)
517576
}
577+
578+
val entryNamesSorted = entriesSorted.joinToString(", ") { it.name.simpleName }
579+
code("val entries: List<$className> = listOf($entryNamesSorted)")
518580
}
519581
}
520582
}
@@ -569,6 +631,8 @@ class ModelToKotlinCommonGenerator(
569631
}
570632
}
571633

634+
private fun MessageDeclaration.allEnumsRecursively(): List<EnumDeclaration> =
635+
enumDeclarations + nestedDeclarations.flatMap(MessageDeclaration::allEnumsRecursively)
572636

573637
private fun String.packageNameSuffixed(suffix: String): String {
574638
return if (isEmpty()) suffix else "$this.$suffix"

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ private fun Descriptors.GenericDescriptor.fqName(): FqName {
7575
FqName.Declaration(usedName, containingType?.fqName() ?: file.fqName())
7676
}
7777

78-
is Descriptors.EnumValueDescriptor -> FqName.Declaration(name, type.fqName())
7978
is Descriptors.OneofDescriptor -> FqName.Declaration(nameCapital, containingType?.fqName() ?: file.fqName())
79+
is Descriptors.EnumDescriptor -> FqName.Declaration(nameCapital, containingType?.fqName() ?: file.fqName())
80+
is Descriptors.EnumValueDescriptor -> FqName.Declaration(name, type.fqName())
8081
is Descriptors.ServiceDescriptor -> FqName.Declaration(nameCapital, file?.fqName() ?: file.fqName())
8182
is Descriptors.MethodDescriptor -> FqName.Declaration(nameLower, service?.fqName() ?: file.fqName())
8283
else -> error("Unknown generic descriptor: $this")
@@ -233,7 +234,7 @@ private fun Descriptors.FieldDescriptor.modelType(): FieldType {
233234
Descriptors.FieldDescriptor.Type.SFIXED64 -> FieldType.IntegralType.SFIXED64
234235
Descriptors.FieldDescriptor.Type.SINT32 -> FieldType.IntegralType.SINT32
235236
Descriptors.FieldDescriptor.Type.SINT64 -> FieldType.IntegralType.SINT64
236-
Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Reference(lazy { enumType!!.toModel().name })
237+
Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Enum(enumType.toModel())
237238
Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Reference(lazy { messageType!!.toModel().name })
238239
Descriptors.FieldDescriptor.Type.GROUP -> error("GROUP type is unsupported")
239240
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ sealed interface FieldType {
3232
data class Entry(val key: FieldType, val value: FieldType)
3333
}
3434

35+
data class Enum(val dec: EnumDeclaration) : FieldType {
36+
override val defaultValue = dec.defaultEntry().name.fullName()
37+
override val wireType: WireType = WireType.VARINT
38+
}
39+
3540
data class Reference(val value: Lazy<FqName>) : FieldType {
3641
override val defaultValue: String = "null"
3742
override val wireType: WireType = WireType.LENGTH_DELIMITED

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ data class EnumDeclaration(
3939
val doc: String?,
4040
val dec: Descriptors.EnumDescriptor,
4141
) {
42+
43+
fun defaultEntry(): Entry {
44+
return originalEntries.minBy { it.dec.number }
45+
}
46+
4247
data class Entry(
4348
val name: FqName,
4449
val doc: String?,

0 commit comments

Comments
 (0)