From 31f322999f93425ee2e5ef007824be2c9fff6c3e Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Mon, 4 Aug 2025 12:08:03 +0200 Subject: [PATCH 1/6] grpc-pb: Enum nullpointer exception state Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 51 +++++++++++ .../commonTest/proto/{exclude => }/enum.proto | 1 + .../src/commonTest/proto/oneof.proto | 8 ++ .../kotlinx/rpc/protobuf/CodeGenerator.kt | 4 +- .../protobuf/ModelToKotlinCommonGenerator.kt | 88 ++++++++++++++++--- .../rpc/protobuf/codeRequestToModel.kt | 5 +- .../kotlinx/rpc/protobuf/model/FieldType.kt | 5 ++ .../kotlinx/rpc/protobuf/model/model.kt | 5 ++ 8 files changed, 152 insertions(+), 15 deletions(-) rename grpc/grpc-core/src/commonTest/proto/{exclude => }/enum.proto (99%) create mode 100644 grpc/grpc-core/src/commonTest/proto/oneof.proto diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 86426cff0..251d9550d 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -6,13 +6,29 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer import kotlinx.rpc.grpc.internal.MessageCodec +import kotlinx.rpc.grpc.test.Enum +import kotlinx.rpc.grpc.test.UsingEnum +import kotlinx.rpc.grpc.test.UsingEnumInternal import kotlinx.rpc.grpc.test.common.* +import kotlinx.rpc.grpc.test.invoke import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith class ProtosTest { + @Test + fun bugDemo() { + val buffer = Buffer() + val encoder = WireEncoder(buffer) + encoder.writeEnum(1, 0) + encoder.flush() + + val decodedMsg = UsingEnumInternal.CODEC.decode(buffer) + assertEquals(Enum.UNRECOGNIZED(50), decodedMsg.enum) + } + + private fun decodeEncode( msg: M, codec: MessageCodec @@ -84,5 +100,40 @@ class ProtosTest { } } + @Test + fun testEnumUnrecognized() { + // write unknown enum value + val buffer = Buffer() + val encoder = WireEncoder(buffer) + encoder.writeEnum(1, 50) + encoder.flush() + + val decodedMsg = UsingEnumInternal.CODEC.decode(buffer) + assertEquals(Enum.UNRECOGNIZED(50), decodedMsg.enum) + } + + @Test + fun testEnumAlias() { + val msg = UsingEnum { + enum = Enum.ONE_SECOND + } + + val decodedMsg = decodeEncode(msg, UsingEnumInternal.CODEC) + assertEquals(Enum.ONE, decodedMsg.enum) + assertEquals(Enum.ONE_SECOND, decodedMsg.enum) + } + + @Test + fun testDefault() { + // create message without enum field set + val msg = UsingEnum {} + + val buffer = UsingEnumInternal.CODEC.encode(msg) as Buffer + // buffer should be empty (default is not in wire) + assertEquals(0, buffer.size) + + val decoded = UsingEnumInternal.CODEC.decode(buffer) + assertEquals(Enum.ZERO, decoded.enum) + } } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/exclude/enum.proto b/grpc/grpc-core/src/commonTest/proto/enum.proto similarity index 99% rename from grpc/grpc-core/src/commonTest/proto/exclude/enum.proto rename to grpc/grpc-core/src/commonTest/proto/enum.proto index 83db7b7fe..36a447243 100644 --- a/grpc/grpc-core/src/commonTest/proto/exclude/enum.proto +++ b/grpc/grpc-core/src/commonTest/proto/enum.proto @@ -14,3 +14,4 @@ enum Enum { message UsingEnum { Enum enum = 1; } + diff --git a/grpc/grpc-core/src/commonTest/proto/oneof.proto b/grpc/grpc-core/src/commonTest/proto/oneof.proto new file mode 100644 index 000000000..da816cb52 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/oneof.proto @@ -0,0 +1,8 @@ +message OneOfMsg { + + oneof f { + int32 i = 2; + fixed64 T = 3; + } + +} \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt index de3f31e20..ebbadd12f 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt @@ -102,10 +102,12 @@ open class CodeGenerator( internal fun whenBlock( condition: String? = null, + prefix: String = "", block: (CodeGenerator.() -> Unit), ) { + val pre = if (prefix.isNotEmpty()) prefix.trim() + " " else "" val cond = condition?.let { " ($it)" } ?: "" - scope("when$cond", block = block) + scope("${pre}when$cond", block = block) } internal fun whenCase( diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 8c79cca77..a257d43e0 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -94,20 +94,30 @@ class ModelToKotlinCommonGenerator( } private fun CodeGenerator.generateInternalDeclaredEntities(fileDeclaration: FileDeclaration) { - fileDeclaration.messageDeclarations.forEach { generateInternalMessage(it) } + generateInternalMessageEntities(fileDeclaration.messageDeclarations) - fileDeclaration.messageDeclarations.forEach { + val allEnums = + fileDeclaration.enumDeclarations + fileDeclaration.messageDeclarations.flatMap { it.allEnumsRecursively() } + allEnums.forEach { enum -> + generateInternalEnumConstructor(enum) + } + } + + private fun CodeGenerator.generateInternalMessageEntities(messages: List) { + messages.forEach { generateInternalMessage(it) } + + messages.forEach { generateMessageConstructor(it) } - fileDeclaration.messageDeclarations.forEach { + messages.forEach { generateRequiredCheck(it) generateMessageEncoder(it) generateMessageDecoder(it) } - } + private fun MessageDeclaration.fields() = actualFields.map { it.transformToFieldDeclaration() to it } @@ -287,6 +297,11 @@ class ModelToKotlinCommonGenerator( } } + is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") { + val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber" + code("$assignment $fromNum(decoder.read$encFuncName())") + } + is FieldType.Map -> TODO() is FieldType.OneOf -> TODO() is FieldType.Reference -> TODO() @@ -338,13 +353,32 @@ class ModelToKotlinCommonGenerator( "$variable.forEach { encoder.write${fieldType.value.decodeEncodeFuncName()}($number, it) }" } + is FieldType.Enum -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable.number)" + is FieldType.Map -> TODO() is FieldType.OneOf -> TODO() - is FieldType.Reference -> TODO() + is FieldType.Reference -> "" } } + private fun CodeGenerator.generateInternalEnumConstructor(enum: EnumDeclaration) { + function( + "fromNumber", + modifiers = "private", + args = "number: Int", + contextReceiver = "${enum.name.safeFullName()}.Companion", + returnType = enum.name.safeFullName(), + ) { + scope("for (entry in entries)") { + ifBranch(condition = "entry.number == number", ifBlock = { + code("return entry") + }) + } + code("return ${enum.name.safeFullName()}.UNRECOGNIZED(number)") + } + } + /** * Generates a function to check for the presence of all required fields in a message declaration. */ @@ -383,6 +417,7 @@ class ModelToKotlinCommonGenerator( else -> error("Unexpected use of size call for field: $name, type: $fieldType") } + is FieldType.Enum -> sizeFunc is FieldType.Map -> TODO() is FieldType.OneOf -> TODO() is FieldType.Reference -> TODO() @@ -397,6 +432,9 @@ class ModelToKotlinCommonGenerator( } is FieldType.List -> "$name.isNotEmpty()" + is FieldType.Reference -> "" + + is FieldType.Enum -> "${fieldType.defaultValue} != $name" else -> TODO("Field: $name, type: $fieldType") } @@ -419,6 +457,7 @@ class ModelToKotlinCommonGenerator( FieldType.IntegralType.SFIXED32 -> "SFixed32" FieldType.IntegralType.SFIXED64 -> "SFixed64" is FieldType.List -> "Packed${value.decodeEncodeFuncName()}" + is FieldType.Enum -> "Enum" is FieldType.Map -> error("No encoding/decoding function for map types") is FieldType.OneOf -> error("No encoding/decoding function for oneOf types") is FieldType.Reference -> error("No encoding/decoding function for sub message types") @@ -435,6 +474,8 @@ class ModelToKotlinCommonGenerator( value.safeFullName() } + is FieldType.Enum -> type.dec.name.safeFullName() + is FieldType.OneOf -> { val value by type.value value.safeFullName() @@ -471,6 +512,7 @@ class ModelToKotlinCommonGenerator( "Map<${fqKey.safeFullName()}, ${fqValue.safeFullName()}>" } + }.withNullability(nullable) } @@ -497,13 +539,30 @@ class ModelToKotlinCommonGenerator( } private fun CodeGenerator.generatePublicEnum(declaration: EnumDeclaration) { - clazz(declaration.name.simpleName, modifiers = "enum") { - declaration.originalEntries.forEach { entry -> - code("${entry.name.simpleName},") - newLine() + + val className = declaration.name.simpleName + + val entriesSorted = declaration.originalEntries.sortedBy { it.dec.number } + + clazz( + className, "sealed", + constructorArgs = listOf("val number: Int"), + ) { + + declaration.originalEntries.forEach { variant -> + clazz( + name = variant.name.simpleName, + declarationType = DeclarationType.Object, + superTypes = listOf("$className(number = ${variant.dec.number})"), + ) } - code(";") - newLine() + + // TODO: Avoid name conflict + clazz( + name = "UNRECOGNIZED", + constructorArgs = listOf("number: Int"), + superTypes = listOf("$className(number)"), + ) if (declaration.aliases.isNotEmpty()) { newLine() @@ -511,10 +570,13 @@ class ModelToKotlinCommonGenerator( clazz("", modifiers = "companion", declarationType = DeclarationType.Object) { declaration.aliases.forEach { alias: EnumDeclaration.Alias -> code( - "val ${alias.name.simpleName}: ${declaration.name.simpleName} " + + "val ${alias.name.simpleName}: $className " + "= ${alias.original.name.simpleName}" ) } + + val entryNamesSorted = entriesSorted.joinToString(", ") { it.name.simpleName } + code("val entries: List<$className> = listOf($entryNamesSorted)") } } } @@ -569,6 +631,8 @@ class ModelToKotlinCommonGenerator( } } +private fun MessageDeclaration.allEnumsRecursively(): List = + enumDeclarations + nestedDeclarations.flatMap(MessageDeclaration::allEnumsRecursively) private fun String.packageNameSuffixed(suffix: String): String { return if (isEmpty()) suffix else "$this.$suffix" diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index a357d5944..f90b35a3b 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -75,8 +75,9 @@ private fun Descriptors.GenericDescriptor.fqName(): FqName { FqName.Declaration(usedName, containingType?.fqName() ?: file.fqName()) } - is Descriptors.EnumValueDescriptor -> FqName.Declaration(name, type.fqName()) is Descriptors.OneofDescriptor -> FqName.Declaration(nameCapital, containingType?.fqName() ?: file.fqName()) + is Descriptors.EnumDescriptor -> FqName.Declaration(nameCapital, containingType?.fqName() ?: file.fqName()) + is Descriptors.EnumValueDescriptor -> FqName.Declaration(name, type.fqName()) is Descriptors.ServiceDescriptor -> FqName.Declaration(nameCapital, file?.fqName() ?: file.fqName()) is Descriptors.MethodDescriptor -> FqName.Declaration(nameLower, service?.fqName() ?: file.fqName()) else -> error("Unknown generic descriptor: $this") @@ -233,7 +234,7 @@ private fun Descriptors.FieldDescriptor.modelType(): FieldType { Descriptors.FieldDescriptor.Type.SFIXED64 -> FieldType.IntegralType.SFIXED64 Descriptors.FieldDescriptor.Type.SINT32 -> FieldType.IntegralType.SINT32 Descriptors.FieldDescriptor.Type.SINT64 -> FieldType.IntegralType.SINT64 - Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Reference(lazy { enumType!!.toModel().name }) + Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Enum(enumType.toModel()) Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Reference(lazy { messageType!!.toModel().name }) Descriptors.FieldDescriptor.Type.GROUP -> error("GROUP type is unsupported") } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt index 50f50846f..6a945b48e 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt @@ -32,6 +32,11 @@ sealed interface FieldType { data class Entry(val key: FieldType, val value: FieldType) } + data class Enum(val dec: EnumDeclaration) : FieldType { + override val defaultValue = dec.defaultEntry().name.fullName() + override val wireType: WireType = WireType.VARINT + } + data class Reference(val value: Lazy) : FieldType { override val defaultValue: String = "null" override val wireType: WireType = WireType.LENGTH_DELIMITED diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index 8eff14079..c477e40e4 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -39,6 +39,11 @@ data class EnumDeclaration( val doc: String?, val dec: Descriptors.EnumDescriptor, ) { + + fun defaultEntry(): Entry { + return originalEntries.minBy { it.dec.number } + } + data class Entry( val name: FqName, val doc: String?, From 5c3c5526db8f64c058780d502a396076b885fec8 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Mon, 4 Aug 2025 14:46:50 +0200 Subject: [PATCH 2/6] grpc-pb: Implement working enum constructor Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 12 ------------ .../protobuf/ModelToKotlinCommonGenerator.kt | 19 ++++++++++++------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 251d9550d..ed259ab9c 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -17,18 +17,6 @@ import kotlin.test.assertFailsWith class ProtosTest { - @Test - fun bugDemo() { - val buffer = Buffer() - val encoder = WireEncoder(buffer) - encoder.writeEnum(1, 0) - encoder.flush() - - val decodedMsg = UsingEnumInternal.CODEC.decode(buffer) - assertEquals(Enum.UNRECOGNIZED(50), decodedMsg.enum) - } - - private fun decodeEncode( msg: M, codec: MessageCodec diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index a257d43e0..c4efe17e9 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -370,12 +370,16 @@ class ModelToKotlinCommonGenerator( contextReceiver = "${enum.name.safeFullName()}.Companion", returnType = enum.name.safeFullName(), ) { - scope("for (entry in entries)") { - ifBranch(condition = "entry.number == number", ifBlock = { - code("return entry") - }) + whenBlock(prefix = "return") { + enum.originalEntries.forEach { entry -> + whenCase("number == ${entry.dec.number}") { + code("${entry.name}") + } + } + whenCase("else") { + code("${enum.name.safeFullName()}.UNRECOGNIZED(number)") + } } - code("return ${enum.name.safeFullName()}.UNRECOGNIZED(number)") } } @@ -546,7 +550,7 @@ class ModelToKotlinCommonGenerator( clazz( className, "sealed", - constructorArgs = listOf("val number: Int"), + constructorArgs = listOf("open val number: Int"), ) { declaration.originalEntries.forEach { variant -> @@ -559,8 +563,9 @@ class ModelToKotlinCommonGenerator( // TODO: Avoid name conflict clazz( + modifiers = "data", name = "UNRECOGNIZED", - constructorArgs = listOf("number: Int"), + constructorArgs = listOf("override val number: Int"), superTypes = listOf("$className(number)"), ) From c6cc2062ca7e384f2ab03ff3bda1182e60d62c4d Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Mon, 4 Aug 2025 17:35:01 +0200 Subject: [PATCH 3/6] grpc-pb: Support OneOf Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 42 ++++++++ .../src/commonTest/proto/oneof.proto | 8 +- .../protobuf/ModelToKotlinCommonGenerator.kt | 99 ++++++++++++------- .../rpc/protobuf/codeRequestToModel.kt | 15 ++- .../kotlinx/rpc/protobuf/model/FieldType.kt | 2 +- .../kotlinx/rpc/protobuf/model/model.kt | 6 +- 6 files changed, 127 insertions(+), 45 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index ed259ab9c..e369084d1 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -4,6 +4,9 @@ package kotlinx.rpc.grpc.pb +import OneOfMsg +import OneOfMsgInternal +import invoke import kotlinx.io.Buffer import kotlinx.rpc.grpc.internal.MessageCodec import kotlinx.rpc.grpc.test.Enum @@ -14,6 +17,7 @@ import kotlinx.rpc.grpc.test.invoke import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith +import kotlin.test.assertNull class ProtosTest { @@ -124,4 +128,42 @@ class ProtosTest { assertEquals(Enum.ZERO, decoded.enum) } + @Test + fun testOneOf() { + val msg1 = OneOfMsg { + field = OneOfMsg.Field.Sint(23) + } + val decoded1 = decodeEncode(msg1, OneOfMsgInternal.CODEC) + assertEquals(OneOfMsg.Field.Sint(23), decoded1.field) + + val msg2 = OneOfMsg { + field = OneOfMsg.Field.Fixed(21u) + } + val decoded2 = decodeEncode(msg2, OneOfMsgInternal.CODEC) + assertEquals(OneOfMsg.Field.Fixed(21u), decoded2.field) + } + + @Test + fun testOneOfLastWins() { + // write two values on the oneOf field. + // the second value must be the one stored during decoding. + val buffer = Buffer() + val encoder = WireEncoder(buffer) + encoder.writeInt32(2, 99) + encoder.writeFixed64(3, 123u) + encoder.flush() + + val decoded = OneOfMsgInternal.CODEC.decode(buffer) + assertEquals(OneOfMsg.Field.Fixed(123u), decoded.field) + } + + @Test + fun testOneOfNull() { + // write two values on the oneOf field. + // the second value must be the one stored during decoding. + val buffer = Buffer() + val decoded = OneOfMsgInternal.CODEC.decode(buffer) + assertNull(decoded.field) + } + } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/oneof.proto b/grpc/grpc-core/src/commonTest/proto/oneof.proto index da816cb52..bde4b3883 100644 --- a/grpc/grpc-core/src/commonTest/proto/oneof.proto +++ b/grpc/grpc-core/src/commonTest/proto/oneof.proto @@ -1,8 +1,6 @@ message OneOfMsg { - - oneof f { - int32 i = 2; - fixed64 T = 3; + oneof field { + int32 sint = 2; + fixed64 fixed = 3; } - } \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index c4efe17e9..886e7fa8f 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -78,6 +78,9 @@ class ModelToKotlinCommonGenerator( generateInternalDeclaredEntities(this@generateInternalKotlinFile) + import("kotlinx.rpc.internal.utils.*") + import("kotlinx.coroutines.flow.*") + additionalInternalImports.forEach { import(it) } @@ -279,12 +282,15 @@ class ModelToKotlinCommonGenerator( code("return msg") } - private fun CodeGenerator.readMatchCase(field: FieldDeclaration) { - val encFuncName = field.type.decodeEncodeFuncName() - val assignment = "msg.${field.name} =" + private fun CodeGenerator.readMatchCase( + field: FieldDeclaration, + assignment: String = "msg.${field.name} =", + wrapperCtor: (String) -> String = { it } + ) { when (val fieldType = field.type) { is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") { - code("$assignment decoder.read$encFuncName()") + val raw = "decoder.read${field.type.decodeEncodeFuncName()}()" + code("$assignment ${wrapperCtor(raw)}") } is FieldType.List -> if (field.dec.isPacked) { @@ -299,11 +305,22 @@ class ModelToKotlinCommonGenerator( is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") { val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber" - code("$assignment $fromNum(decoder.read$encFuncName())") + val raw = "$fromNum(decoder.read${field.type.decodeEncodeFuncName()}())" + code("$assignment ${wrapperCtor(raw)}") + } + + is FieldType.OneOf -> { + fieldType.dec.variants.forEach { variant -> + val variantName = "${fieldType.dec.name.safeFullName()}.${variant.name}" + readMatchCase( + field = variant, + assignment = assignment, + wrapperCtor = { "$variantName($it)" } + ) + } } is FieldType.Map -> TODO() - is FieldType.OneOf -> TODO() is FieldType.Reference -> TODO() } } @@ -323,42 +340,54 @@ class ModelToKotlinCommonGenerator( val fieldName = field.name if (field.nullable) { scope("$fieldName?.also") { - code(field.writeValue("it")) + writeFieldValue(field, "it") } } else if (!field.dec.hasPresence()) { ifBranch(condition = field.defaultCheck(), ifBlock = { - code(field.writeValue(field.name)) + writeFieldValue(field, field.name) }) } else { - code(field.writeValue(field.name)) + writeFieldValue(field, field.name) } } } - private fun FieldDeclaration.writeValue(variable: String): String { - return when (val fieldType = type) { - is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)" - is FieldType.List -> when { - dec.isPacked && packedFixedSize -> - "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)" + private fun CodeGenerator.writeFieldValue(field: FieldDeclaration, valueVar: String) { + var encFunc = field.type.decodeEncodeFuncName() + val number = field.number + when (val fieldType = field.type) { + is FieldType.IntegralType -> code("encoder.write${encFunc!!}(fieldNr = $number, value = $valueVar)") + is FieldType.List -> { + encFunc = fieldType.value.decodeEncodeFuncName() + when { + field.dec.isPacked && field.packedFixedSize -> + code("encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar)") - dec.isPacked && !packedFixedSize -> - "encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable, fieldSize = ${ - wireSizeCall( - variable + field.dec.isPacked && !field.packedFixedSize -> + code( + "encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar, fieldSize = ${ + field.wireSizeCall(valueVar) + })" ) - })" - else -> - "$variable.forEach { encoder.write${fieldType.value.decodeEncodeFuncName()}($number, it) }" + else -> code("$valueVar.forEach { encoder.write${encFunc!!}($number, it) }") + } } - is FieldType.Enum -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable.number)" + is FieldType.Enum -> code("encoder.write${encFunc!!}(fieldNr = $number, value = ${valueVar}.number)") + + is FieldType.OneOf -> whenBlock("val value = $valueVar") { + fieldType.dec.variants.forEach { variant -> + whenCase("is ${fieldType.dec.name.safeFullName()}.${variant.name}") { + writeFieldValue(variant, "value.value") + } + } + } is FieldType.Map -> TODO() - is FieldType.OneOf -> TODO() - is FieldType.Reference -> "" + is FieldType.Reference -> code("") } + } @@ -370,9 +399,9 @@ class ModelToKotlinCommonGenerator( contextReceiver = "${enum.name.safeFullName()}.Companion", returnType = enum.name.safeFullName(), ) { - whenBlock(prefix = "return") { + whenBlock(prefix = "return", condition = "number") { enum.originalEntries.forEach { entry -> - whenCase("number == ${entry.dec.number}") { + whenCase("${entry.dec.number}") { code("${entry.name}") } } @@ -408,7 +437,8 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.wireSizeCall(variable: String): String { - val sizeFunc = "$PB_PKG.WireSize.${type.decodeEncodeFuncName().replaceFirstChar { it.lowercase() }}($variable)" + val sizeFunc = + "$PB_PKG.WireSize.${type.decodeEncodeFuncName()!!.replaceFirstChar { it.lowercase() }}($variable)" return when (val fieldType = type) { is FieldType.IntegralType -> when { fieldType.wireType == WireType.FIXED32 -> "32" @@ -444,7 +474,7 @@ class ModelToKotlinCommonGenerator( } } - private fun FieldType.decodeEncodeFuncName(): String = when (this) { + private fun FieldType.decodeEncodeFuncName(): String? = when (this) { FieldType.IntegralType.STRING -> "String" FieldType.IntegralType.BYTES -> "Bytes" FieldType.IntegralType.BOOL -> "Bool" @@ -462,9 +492,9 @@ class ModelToKotlinCommonGenerator( FieldType.IntegralType.SFIXED64 -> "SFixed64" is FieldType.List -> "Packed${value.decodeEncodeFuncName()}" is FieldType.Enum -> "Enum" - is FieldType.Map -> error("No encoding/decoding function for map types") - is FieldType.OneOf -> error("No encoding/decoding function for oneOf types") - is FieldType.Reference -> error("No encoding/decoding function for sub message types") + is FieldType.Map -> null + is FieldType.OneOf -> null + is FieldType.Reference -> null } private fun FieldDeclaration.transformToFieldDeclaration(): String { @@ -480,10 +510,7 @@ class ModelToKotlinCommonGenerator( is FieldType.Enum -> type.dec.name.safeFullName() - is FieldType.OneOf -> { - val value by type.value - value.safeFullName() - } + is FieldType.OneOf -> type.dec.name.safeFullName() is FieldType.IntegralType -> { type.fqName.simpleName diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index f90b35a3b..aaa9b48db 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -113,20 +113,31 @@ private fun Descriptors.FileDescriptor.toModel(): FileDeclaration = cached { private fun Descriptors.Descriptor.toModel(): MessageDeclaration = cached { var currPresenceIdx = 0 - val regularFields = fields + var regularFields = fields // only fields that are not part of a oneOf declaration .filter { field -> field.realContainingOneof == null } .map { val presenceIdx = if (it.hasPresence()) currPresenceIdx++ else null it.toModel(presenceIdx = presenceIdx) } + val oneOfs = oneofs.filter { it.fields[0].realContainingOneof != null }.map { it.toModel() } + + regularFields = regularFields + oneOfs.map { + FieldDeclaration( + // TODO: Proper handling of this field name + it.name.simpleName.lowercase(), + FieldType.OneOf(it), + doc = null, + dec = it.variants.first().dec, + ) + } return MessageDeclaration( name = fqName(), presenceMaskSize = currPresenceIdx, actualFields = regularFields, // get all oneof declarations that are not created from an optional in proto3 https://github.com/googleapis/api-linter/issues/1323 - oneOfDeclarations = oneofs.filter { it.fields[0].realContainingOneof != null }.map { it.toModel() }, + oneOfDeclarations = oneOfs, enumDeclarations = enumTypes.map { it.toModel() }, nestedDeclarations = nestedTypes.map { it.toModel() }, doc = null, diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt index 6a945b48e..76c64925d 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt @@ -42,7 +42,7 @@ sealed interface FieldType { override val wireType: WireType = WireType.LENGTH_DELIMITED } - data class OneOf(val value: Lazy, val index: Int) : FieldType { + data class OneOf(val dec: OneOfDeclaration) : FieldType { override val defaultValue: String = "null" override val wireType: WireType = WireType.LENGTH_DELIMITED } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index c477e40e4..275f2c104 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -76,7 +76,11 @@ data class FieldDeclaration( val packedFixedSize = type.wireType == WireType.FIXED64 || type.wireType == WireType.FIXED32 // aligns with edition settings and backward compatibility with proto2 and proto3 - val nullable: Boolean = dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() && !dec.isRepeated + val nullable: Boolean = (dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() + && !dec.isRepeated // repeated fields cannot be nullable (just empty) + && dec.realContainingOneof == null // upper conditions would match oneof inner fields + ) + || type is FieldType.OneOf // all OneOf fields are nullable val number: Int = dec.number } From d78a69674fde632bc164ee823c294f941b0b653b Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Mon, 4 Aug 2025 20:05:24 +0200 Subject: [PATCH 4/6] grpc-pb: Fix imports Signed-off-by: Johannes Zottele --- grpc/grpc-core/build.gradle.kts | 8 ++++++++ .../kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt | 1 + 2 files changed, 9 insertions(+) diff --git a/grpc/grpc-core/build.gradle.kts b/grpc/grpc-core/build.gradle.kts index 7700643b1..623e1ed7e 100644 --- a/grpc/grpc-core/build.gradle.kts +++ b/grpc/grpc-core/build.gradle.kts @@ -7,6 +7,7 @@ import kotlinx.rpc.proto.kotlinMultiplatform import org.gradle.internal.extensions.stdlib.capitalized import org.jetbrains.kotlin.gradle.plugin.mpp.KotlinNativeTarget import org.jetbrains.kotlin.gradle.tasks.CInteropProcess +import org.jetbrains.kotlin.gradle.tasks.KotlinCompile plugins { alias(libs.plugins.conventions.kmp) @@ -159,5 +160,12 @@ rpc { dependsOn(gradle.includedBuild("protoc-gen").task(":jar")) } } + + // generate protos before compiling tests + project.tasks.withType().configureEach { + if (name.startsWith("compileTest")) { + dependsOn(project.tasks.withType()) + } + } } } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 886e7fa8f..8142c43ca 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -78,6 +78,7 @@ class ModelToKotlinCommonGenerator( generateInternalDeclaredEntities(this@generateInternalKotlinFile) + import("kotlinx.rpc.grpc.pb.*") import("kotlinx.rpc.internal.utils.*") import("kotlinx.coroutines.flow.*") From 498d20aefcf9d7714637631090e67fa5afaac061 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Tue, 5 Aug 2025 08:54:50 +0200 Subject: [PATCH 5/6] grpc-pb: Fix lowercase name of oneof field Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt | 2 +- .../kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt | 2 +- protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/utils.kt | 9 +++++++++ 3 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/utils.kt diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 8142c43ca..b3739778a 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -439,7 +439,7 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.wireSizeCall(variable: String): String { val sizeFunc = - "$PB_PKG.WireSize.${type.decodeEncodeFuncName()!!.replaceFirstChar { it.lowercase() }}($variable)" + "$PB_PKG.WireSize.${type.decodeEncodeFuncName()!!.decapitalize()}($variable)" return when (val fieldType = type) { is FieldType.IntegralType -> when { fieldType.wireType == WireType.FIXED32 -> "32" diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index aaa9b48db..da10cdac8 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -125,7 +125,7 @@ private fun Descriptors.Descriptor.toModel(): MessageDeclaration = cached { regularFields = regularFields + oneOfs.map { FieldDeclaration( // TODO: Proper handling of this field name - it.name.simpleName.lowercase(), + it.name.simpleName.decapitalize(), FieldType.OneOf(it), doc = null, dec = it.variants.first().dec, diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/utils.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/utils.kt new file mode 100644 index 000000000..2b44fdcac --- /dev/null +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/utils.kt @@ -0,0 +1,9 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.protobuf + +internal fun String.decapitalize(): String { + return this.replaceFirstChar { it.lowercase() } +} \ No newline at end of file From 8f1a7dce14188ee2521272da9497347490e3d333 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Tue, 5 Aug 2025 14:05:18 +0200 Subject: [PATCH 6/6] grpc-pb: Address PR comments Signed-off-by: Johannes Zottele --- grpc/grpc-core/src/commonTest/proto/enum.proto | 1 - .../kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt | 8 +++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/proto/enum.proto b/grpc/grpc-core/src/commonTest/proto/enum.proto index 36a447243..83db7b7fe 100644 --- a/grpc/grpc-core/src/commonTest/proto/enum.proto +++ b/grpc/grpc-core/src/commonTest/proto/enum.proto @@ -14,4 +14,3 @@ enum Enum { message UsingEnum { Enum enum = 1; } - diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index b3739778a..74ad5fe59 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -438,8 +438,10 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.wireSizeCall(variable: String): String { + val sizeFunName = + type.decodeEncodeFuncName()?.decapitalize() ?: error("No decodeEncodeFuncName for type: $type") val sizeFunc = - "$PB_PKG.WireSize.${type.decodeEncodeFuncName()!!.decapitalize()}($variable)" + "$PB_PKG.WireSize.$sizeFunName($variable)" return when (val fieldType = type) { is FieldType.IntegralType -> when { fieldType.wireType == WireType.FIXED32 -> "32" @@ -604,12 +606,12 @@ class ModelToKotlinCommonGenerator( declaration.aliases.forEach { alias: EnumDeclaration.Alias -> code( "val ${alias.name.simpleName}: $className " + - "= ${alias.original.name.simpleName}" + "get() = ${alias.original.name.simpleName}" ) } val entryNamesSorted = entriesSorted.joinToString(", ") { it.name.simpleName } - code("val entries: List<$className> = listOf($entryNamesSorted)") + code("val entries: Lazy> = lazy { listOf($entryNamesSorted) }") } } }