From 9964bafb7b0a3a45b34871c4f6a8f2409b5c4908 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Tue, 5 Aug 2025 12:19:33 +0200 Subject: [PATCH 01/11] grpc-pb: Implement sub messages Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/grpc/pb/InternalMessage.kt | 1 + .../kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt | 12 ++ .../kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt | 7 + .../kotlin/kotlinx/rpc/grpc/pb/WireSize.kt | 46 ++++ .../src/commonTest/proto/submsg.proto | 8 + .../kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt | 11 +- .../kotlinx/rpc/grpc/pb/WireEncoder.native.kt | 10 + .../protobuf/ModelToKotlinCommonGenerator.kt | 199 +++++++++++++----- .../rpc/protobuf/codeRequestToModel.kt | 2 +- .../kotlinx/rpc/protobuf/model/FieldType.kt | 6 +- 10 files changed, 247 insertions(+), 55 deletions(-) create mode 100644 grpc/grpc-core/src/commonTest/proto/submsg.proto diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt index f9d07f017..370f95f58 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt @@ -10,4 +10,5 @@ import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi public abstract class InternalMessage(fieldsWithPresence: Int) { public val presenceMask: BitSet = BitSet(fieldsWithPresence) + public abstract val _size: Int } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt index 167ad9d1e..168017cc7 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt @@ -5,6 +5,8 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer +import kotlinx.rpc.grpc.internal.popLimit +import kotlinx.rpc.grpc.internal.pushLimit import kotlinx.rpc.internal.utils.InternalRpcApi // TODO: Evaluate if this buffer size is suitable for all targets (KRPC-186) @@ -76,6 +78,16 @@ public interface WireDecoder : AutoCloseable { public fun readPackedFloat(): List public fun readPackedDouble(): List public fun readPackedEnum(): List + + // TODO: Throw error instead of just returning + public fun readMessage(msg: T, decoder: (T, WireDecoder) -> Unit) { + val len = readInt32() + if (hadError()) return + if (len <= 0) return + val limit = pushLimit(len) + decoder(msg, this) + popLimit(limit) + } } /** diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt index 028040823..69fd70cb0 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt @@ -50,6 +50,13 @@ public interface WireEncoder { public fun writePackedDouble(fieldNr: Int, value: List): Boolean public fun writePackedEnum(fieldNr: Int, value: List, fieldSize: Int): Boolean = writePackedInt32(fieldNr, value, fieldSize) + + public fun writeMessage( + fieldNr: Int, + value: T, + encode: (WireEncoder) -> Unit + ) + } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireSize.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireSize.kt index 50f957c1b..dbd2b5bc5 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireSize.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireSize.kt @@ -4,11 +4,15 @@ package kotlinx.rpc.grpc.pb +import kotlinx.io.bytestring.encodeToByteString import kotlinx.rpc.internal.utils.InternalRpcApi @InternalRpcApi public object WireSize +public fun WireSize.tag(fieldNumber: Int, wireType: WireType): Int = + uInt32(KTag(fieldNumber, wireType).toRawKTag()) + @InternalRpcApi public expect fun WireSize.int32(value: Int): Int @@ -27,12 +31,36 @@ public expect fun WireSize.sInt32(value: Int): Int @InternalRpcApi public expect fun WireSize.sInt64(value: Long): Int +@InternalRpcApi +public fun WireSize.float(value: Float): Int = 32 + +@InternalRpcApi +public fun WireSize.double(value: Double): Int = 64 + +@InternalRpcApi +public fun WireSize.fixed32(value: UInt): Int = 32 + +@InternalRpcApi +public fun WireSize.fixed64(value: ULong): Int = 64 + +@InternalRpcApi +public fun WireSize.sFixed32(value: Int): Int = 32 + +@InternalRpcApi +public fun WireSize.sFixed64(value: Long): Int = 64 + @InternalRpcApi public fun WireSize.bool(value: Boolean): Int = int32(if (value) 1 else 0) @InternalRpcApi public fun WireSize.enum(value: Int): Int = int32(value) +@InternalRpcApi +public fun WireSize.bytes(value: ByteArray): Int = value.size + +@InternalRpcApi +public fun WireSize.string(value: String): Int = value.encodeToByteString().size + @InternalRpcApi public fun WireSize.packedInt32(value: List): Int = value.sumOf { int32(it) } @@ -53,3 +81,21 @@ public fun WireSize.packedSInt64(value: List): Int = value.sumOf { sInt64( @InternalRpcApi public fun WireSize.packedEnum(value: List): Int = value.sumOf { enum(it) } + +@InternalRpcApi +public fun WireSize.packedFloat(value: List): Int = value.size * 32 + +@InternalRpcApi +public fun WireSize.packedDouble(value: List): Int = value.size * 64 + +@InternalRpcApi +public fun WireSize.packedFixed32(value: List): Int = value.size * 32 + +@InternalRpcApi +public fun WireSize.packedFixed64(value: List): Int = value.size * 64 + +@InternalRpcApi +public fun WireSize.packedSFixed32(value: List): Int = value.size * 32 + +@InternalRpcApi +public fun WireSize.packedSFixed64(value: List): Int = value.size * 64 diff --git a/grpc/grpc-core/src/commonTest/proto/submsg.proto b/grpc/grpc-core/src/commonTest/proto/submsg.proto new file mode 100644 index 000000000..76b4b51a8 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/submsg.proto @@ -0,0 +1,8 @@ +message Outer { + + required Inner inner = 1; + + message Inner { + required int32 field = 1; + } +} \ No newline at end of file diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt index 00fb1f869..6fea0c0af 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt @@ -4,7 +4,6 @@ package kotlinx.rpc.grpc.pb -import com.google.protobuf.ByteString import com.google.protobuf.CodedOutputStream import kotlinx.io.Sink import kotlinx.io.asOutputStream @@ -201,6 +200,16 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { return true } + override fun writeMessage( + fieldNr: Int, + value: T, + encode: (WireEncoder) -> Unit + ) { + codedOutputStream.writeTag(fieldNr, WireType.LENGTH_DELIMITED.ordinal) + codedOutputStream.writeInt32NoTag(value._size) + encode(this) + } + private inline fun writePackedInternal( fieldNr: Int, value: List, diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt index 261a8b281..249867f6c 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt @@ -162,6 +162,16 @@ internal class WireEncoderNative(private val sink: Sink) : WireEncoder { override fun writePackedDouble(fieldNr: Int, value: List) = writePackedInternal(fieldNr, value, value.size * Double.SIZE_BYTES, ::pw_encoder_write_double_no_tag) + + override fun writeMessage( + fieldNr: Int, + value: T, + encode: (WireEncoder) -> Unit + ) { + pw_encoder_write_tag(raw, fieldNr, WireType.LENGTH_DELIMITED.ordinal) + pw_encoder_write_int32_no_tag(raw, value._size) + encode(this) + } } internal actual fun WireEncoder(sink: Sink): WireEncoder = WireEncoderNative(sink) diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index a37f07b47..af6a3ab3b 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -111,14 +111,17 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateInternalMessageEntities(messages: List) { messages.forEach { generateInternalMessage(it) } - messages.forEach { + // emit all required functions in the outer scope + val allMsgs = messages + messages.flatMap(MessageDeclaration::nestedDeclarations) + allMsgs.forEach { generateMessageConstructor(it) } - - messages.forEach { + allMsgs.forEach { generateRequiredCheck(it) generateMessageEncoder(it) generateMessageDecoder(it) + generateInternalComputeSize(it) + generateInternalCastExtension(it) } } @@ -171,16 +174,15 @@ class ModelToKotlinCommonGenerator( generatePresenceIndicesObject(declaration) + code("override val _size: Int by lazy { computeSize() }") + declaration.fields().forEach { (fieldDeclaration, field) -> val value = when { field.nullable -> { "= null" } - field.type is FieldType.Reference -> { - additionalInternalImports.add("kotlin.properties.Delegates") - "by Delegates.notNull()" - } + field.type is FieldType.Reference -> "= ${field.type.dec.internalClassFullName()}()" else -> { "= ${field.type.defaultValue}" @@ -227,17 +229,19 @@ class ModelToKotlinCommonGenerator( val bufferFqName = "kotlinx.io.Buffer" scope("object CODEC : kotlinx.rpc.grpc.codec.MessageCodec<$msgFqName>") { function("encode", modifiers = "override", args = "value: $msgFqName", returnType = sourceFqName) { - code("val msg = value as? ${declaration.internalClassFullName()} ?: error { \"$downCastErrorStr\" }") code("val buffer = $bufferFqName()") code("val encoder = $PB_PKG.WireEncoder(buffer)") - code("msg.encodeWith(encoder)") + code("value.asInternal().encodeWith(encoder)") code("encoder.flush()") code("return buffer") } function("decode", modifiers = "override", args = "stream: $sourceFqName", returnType = msgFqName) { scope("$PB_PKG.WireDecoder(stream as $bufferFqName).use") { - code("return ${declaration.internalClassFullName()}.CODEC.decodeWith(it)") + code("val msg = ${declaration.internalClassFullName()}()") + code("${declaration.internalClassFullName()}.CODEC.decodeWith(msg, it)") + code("msg.checkRequiredFields()") + code("return msg") } } } @@ -258,12 +262,10 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageDecoder(declaration: MessageDeclaration) = function( name = "decodeWith", - modifiers = "private", - args = "decoder: $PB_PKG.WireDecoder", - contextReceiver = "${declaration.internalClassFullName()}.CODEC", - returnType = declaration.internalClassName() + modifiers = "internal", + args = "msg: ${declaration.internalClassFullName()}, decoder: $PB_PKG.WireDecoder", + contextReceiver = "${declaration.internalClassFullName()}.CODEC" ) { - code("val msg = ${declaration.internalClassFullName()}()") whileBlock("!decoder.hadError()") { code("val tag = decoder.readTag() ?: break // EOF, we read the whole message") whenBlock { @@ -276,26 +278,23 @@ class ModelToKotlinCommonGenerator( ifBlock = { code("error(\"Error during decoding of ${declaration.name.simpleName}\")") } ) - code("msg.checkRequiredFields()") - // TODO: Make a lists immutable - code("return msg") } private fun CodeGenerator.readMatchCase( field: FieldDeclaration, - assignment: String = "msg.${field.name} =", + lvalue: String = "msg.${field.name}", wrapperCtor: (String) -> String = { it } ) { when (val fieldType = field.type) { is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") { val raw = "decoder.read${field.type.decodeEncodeFuncName()}()" - code("$assignment ${wrapperCtor(raw)}") + code("$lvalue = ${wrapperCtor(raw)}") } is FieldType.List -> if (field.dec.isPacked) { whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { - code("$assignment decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()") + code("$lvalue = decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()") } } else { whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${fieldType.value.wireType.name}") { @@ -306,7 +305,7 @@ class ModelToKotlinCommonGenerator( is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") { val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber" val raw = "$fromNum(decoder.read${field.type.decodeEncodeFuncName()}())" - code("$assignment ${wrapperCtor(raw)}") + code("$lvalue = ${wrapperCtor(raw)}") } is FieldType.OneOf -> { @@ -314,20 +313,26 @@ class ModelToKotlinCommonGenerator( val variantName = "${fieldType.dec.name.safeFullName()}.${variant.name}" readMatchCase( field = variant, - assignment = assignment, + lvalue = lvalue, wrapperCtor = { "$variantName($it)" } ) } } + is FieldType.Reference -> { + val internalClassName = fieldType.dec.internalClassFullName() + whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { + code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)") + } + } + is FieldType.Map -> TODO() - is FieldType.Reference -> TODO() } } private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function( name = "encodeWith", - modifiers = "private", + modifiers = "internal", args = "encoder: $PB_PKG.WireEncoder", contextReceiver = declaration.internalClassFullName(), ) { @@ -343,7 +348,7 @@ class ModelToKotlinCommonGenerator( writeFieldValue(field, "it") } } else if (!field.dec.hasPresence()) { - ifBranch(condition = field.defaultCheck(), ifBlock = { + ifBranch(condition = field.notDefaultCheck(), ifBlock = { writeFieldValue(field, field.name) }) } else { @@ -366,7 +371,7 @@ class ModelToKotlinCommonGenerator( field.dec.isPacked && !field.packedFixedSize -> code( "encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar, fieldSize = ${ - field.wireSizeCall(valueVar) + field.valueSizeCall(valueVar) })" ) @@ -385,7 +390,8 @@ class ModelToKotlinCommonGenerator( } is FieldType.Map -> TODO() - is FieldType.Reference -> code("") + + is FieldType.Reference -> code("encoder.writeMessage(fieldNr = ${field.number}, value = $valueVar.asInternal()) { encodeWith(it) }") } } @@ -435,32 +441,123 @@ class ModelToKotlinCommonGenerator( } } + private fun CodeGenerator.generateInternalComputeSize(declaration: MessageDeclaration) { + function( + name = "computeSize", + modifiers = "private", + contextReceiver = declaration.internalClassFullName(), + returnType = "Int", + ) { + code("var result = 0") + declaration.actualFields.forEach { field -> + val fieldName = field.name + if (field.nullable) { + scope("$fieldName?.also") { + generateFieldComputeSizeCall(field, "it") + } + } else if (!field.dec.hasPresence()) { + scope("if (${field.notDefaultCheck()})") { + generateFieldComputeSizeCall(field, fieldName) + } + } else { + generateFieldComputeSizeCall(field, fieldName) + } + } + code("return result") + } + } - private fun FieldDeclaration.wireSizeCall(variable: String): String { - val sizeFunName = - type.decodeEncodeFuncName()?.decapitalize() ?: error("No decodeEncodeFuncName for type: $type") - val sizeFunc = - "$PB_PKG.WireSize.$sizeFunName($variable)" - return when (val fieldType = type) { - is FieldType.IntegralType -> when { - fieldType.wireType == WireType.FIXED32 -> "32" - fieldType.wireType == WireType.FIXED64 -> "64" - else -> sizeFunc + private fun CodeGenerator.generateInternalCastExtension(declaration: MessageDeclaration) { + function( + "asInternal", + modifiers = "private", + contextReceiver = declaration.name.safeFullName(), + returnType = declaration.internalClassFullName(), + ) { + code("return this as? ${declaration.internalClassFullName()} ?: error(\"Message \${this::class.simpleName} is a non-internal message type.\")") + } + } + + + private fun CodeGenerator.generateFieldComputeSizeCall(field: FieldDeclaration, variable: String) { + val valueSize by lazy { field.valueSizeCall(variable) } + val tagSize = tagSizeCall(field.number, field.type.wireType) + + when (field.type) { + is FieldType.List -> when { + // packed fields also have the tag + len + field.dec.isPacked -> code("result += $valueSize.let { $tagSize + ${int32SizeCall("it")} + it }") + else -> code("result = $valueSize") + } + + is FieldType.Reference, + FieldType.IntegralType.STRING, + FieldType.IntegralType.BYTES -> code("$valueSize.let { $tagSize + ${int32SizeCall("it")} + it }") + + is FieldType.Map -> TODO() + is FieldType.OneOf -> whenBlock("val value = $variable") { + field.type.dec.variants.forEach { variant -> + val variantName = "${field.type.dec.name.safeFullName()}.${variant.name}" + whenCase("is $variantName") { + generateFieldComputeSizeCall(variant, "value.value") + } + } } + is FieldType.Enum, + FieldType.IntegralType.BOOL, + FieldType.IntegralType.FLOAT, + FieldType.IntegralType.DOUBLE, + FieldType.IntegralType.INT32, + FieldType.IntegralType.INT64, + FieldType.IntegralType.UINT32, + FieldType.IntegralType.UINT64, + FieldType.IntegralType.FIXED32, + FieldType.IntegralType.FIXED64, + FieldType.IntegralType.SINT32, + FieldType.IntegralType.SINT64, + FieldType.IntegralType.SFIXED32, + FieldType.IntegralType.SFIXED64 -> code("result += ($tagSize + $valueSize)") + } + } + + private fun FieldDeclaration.valueSizeCall(variable: String): String { + val sizeFunName = type.decodeEncodeFuncName()?.decapitalize() + val sizeFunc = "$PB_PKG.WireSize.$sizeFunName($variable)" + + return when (type) { + is FieldType.IntegralType -> sizeFunc + is FieldType.List -> when { - dec.isPacked && !packedFixedSize -> sizeFunc - else -> error("Unexpected use of size call for field: $name, type: $fieldType") + dec.isPacked -> sizeFunc + else -> { + // calculate the size of the values within the list. + val valueTypeSizeFunc = type.value.decodeEncodeFuncName()?.decapitalize() + "$variable.sumOf { $PB_PKG.WireSize.$valueTypeSizeFunc(it) + ${ + tagSizeCall( + number, + type.value.wireType + ) + } }" + } } - is FieldType.Enum -> sizeFunc + is FieldType.Enum -> "$PB_PKG.WireSize.$sizeFunName($variable.number)" is FieldType.Map -> TODO() - is FieldType.OneOf -> TODO() - is FieldType.Reference -> TODO() + is FieldType.OneOf -> error("OneOf fields have no direct valueSizeCall") + is FieldType.Reference -> "$variable.asInternal()._size" } } - private fun FieldDeclaration.defaultCheck(): String { + private fun tagSizeCall(number: Int, wireType: WireType): String { + return "$PB_PKG.WireSize.tag($number, $PB_PKG.WireType.$wireType)" + } + + private fun int32SizeCall(number: String): String { + return "$PB_PKG.WireSize.int32($number)" + } + + private fun FieldDeclaration.notDefaultCheck(): String { return when (val fieldType = type) { is FieldType.IntegralType -> when (fieldType) { FieldType.IntegralType.BYTES, FieldType.IntegralType.STRING -> "$name.isNotEmpty()" @@ -506,8 +603,7 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.typeFqName(): String { return when (type) { is FieldType.Reference -> { - val value by type.value - value.safeFullName() + type.dec.name.safeFullName() } is FieldType.Enum -> type.dec.name.safeFullName() @@ -520,7 +616,7 @@ class ModelToKotlinCommonGenerator( is FieldType.List -> { val fqValue = when (val value = type.value) { - is FieldType.Reference -> value.value.value + is FieldType.Reference -> value.dec.name is FieldType.IntegralType -> value.fqName else -> error("Unsupported type: $value") } @@ -532,13 +628,13 @@ class ModelToKotlinCommonGenerator( val entry by type.entry val fqKey = when (val key = entry.key) { - is FieldType.Reference -> key.value.value + is FieldType.Reference -> key.dec.name is FieldType.IntegralType -> key.fqName else -> error("Unsupported type: $key") } val fqValue = when (val value = entry.value) { - is FieldType.Reference -> value.value.value + is FieldType.Reference -> value.dec.name is FieldType.IntegralType -> value.fqName else -> error("Unsupported type: $value") } @@ -610,7 +706,7 @@ class ModelToKotlinCommonGenerator( } val entryNamesSorted = entriesSorted.joinToString(", ") { it.name.simpleName } - code("val entries: Lazy> = lazy { listOf($entryNamesSorted) }") + code("val entries: List<$className> by lazy { listOf($entryNamesSorted) }") } } } @@ -668,6 +764,9 @@ class ModelToKotlinCommonGenerator( private fun MessageDeclaration.allEnumsRecursively(): List = enumDeclarations + nestedDeclarations.flatMap(MessageDeclaration::allEnumsRecursively) +private fun MessageDeclaration.allNestedRecursively(): List = + nestedDeclarations + nestedDeclarations.flatMap(MessageDeclaration::allNestedRecursively) + private fun String.packageNameSuffixed(suffix: String): String { return if (isEmpty()) suffix else "$this.$suffix" } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index da10cdac8..157c01aff 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -246,7 +246,7 @@ private fun Descriptors.FieldDescriptor.modelType(): FieldType { Descriptors.FieldDescriptor.Type.SINT32 -> FieldType.IntegralType.SINT32 Descriptors.FieldDescriptor.Type.SINT64 -> FieldType.IntegralType.SINT64 Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Enum(enumType.toModel()) - Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Reference(lazy { messageType!!.toModel().name }) + Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Reference(messageType!!.toModel()) Descriptors.FieldDescriptor.Type.GROUP -> error("GROUP type is unsupported") } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt index 76c64925d..7a33143f0 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt @@ -14,7 +14,7 @@ enum class WireType { } sealed interface FieldType { - val defaultValue: String + val defaultValue: String? val wireType: WireType val isPackable: Boolean get() = false @@ -37,8 +37,8 @@ sealed interface FieldType { override val wireType: WireType = WireType.VARINT } - data class Reference(val value: Lazy) : FieldType { - override val defaultValue: String = "null" + data class Reference(val dec: MessageDeclaration) : FieldType { + override val defaultValue: String? = null override val wireType: WireType = WireType.LENGTH_DELIMITED } From 9f9a6a86554106f7baa1a397c98c327095794808 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 08:55:37 +0200 Subject: [PATCH 02/11] grpc-pb: First successful sub-message test Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt | 2 +- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 13 ++++++ .../kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt | 4 +- .../kotlinx/rpc/grpc/pb/WireEncoder.native.kt | 4 +- .../protobuf/ModelToKotlinCommonGenerator.kt | 42 ++++++++++++------- .../rpc/protobuf/codeRequestToModel.kt | 2 +- .../kotlinx/rpc/protobuf/model/FieldType.kt | 2 +- 7 files changed, 46 insertions(+), 23 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt index 69fd70cb0..d5fd1374c 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt @@ -54,7 +54,7 @@ public interface WireEncoder { public fun writeMessage( fieldNr: Int, value: T, - encode: (WireEncoder) -> Unit + encode: T.(WireEncoder) -> Unit ) } diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 8ca388405..f857f95e3 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -6,6 +6,8 @@ package kotlinx.rpc.grpc.pb import OneOfMsg import OneOfMsgInternal +import Outer +import OuterInternal import invoke import kotlinx.io.Buffer import kotlinx.rpc.grpc.codec.MessageCodec @@ -166,4 +168,15 @@ class ProtosTest { assertNull(decoded.field) } + @Test + fun testSubMessage() { + val msg = Outer { + inner = Outer.Inner { + field = 12345678 + } + } + val decoded = decodeEncode(msg, OuterInternal.CODEC) + assertEquals(msg.inner.field, decoded.inner.field) + } + } diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt index 6fea0c0af..3c40a2263 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt @@ -203,11 +203,11 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { override fun writeMessage( fieldNr: Int, value: T, - encode: (WireEncoder) -> Unit + encode: T.(WireEncoder) -> Unit ) { codedOutputStream.writeTag(fieldNr, WireType.LENGTH_DELIMITED.ordinal) codedOutputStream.writeInt32NoTag(value._size) - encode(this) + value.encode(this) } private inline fun writePackedInternal( diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt index 249867f6c..d66cef946 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt @@ -166,11 +166,11 @@ internal class WireEncoderNative(private val sink: Sink) : WireEncoder { override fun writeMessage( fieldNr: Int, value: T, - encode: (WireEncoder) -> Unit + encode: T.(WireEncoder) -> Unit ) { pw_encoder_write_tag(raw, fieldNr, WireType.LENGTH_DELIMITED.ordinal) pw_encoder_write_int32_no_tag(raw, value._size) - encode(this) + value.encode(this) } } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index af6a3ab3b..9d1d24f67 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -182,7 +182,7 @@ class ModelToKotlinCommonGenerator( "= null" } - field.type is FieldType.Reference -> "= ${field.type.dec.internalClassFullName()}()" + field.type is FieldType.Message -> "= ${field.type.dec.internalClassFullName()}()" else -> { "= ${field.type.defaultValue}" @@ -319,9 +319,13 @@ class ModelToKotlinCommonGenerator( } } - is FieldType.Reference -> { + is FieldType.Message -> { val internalClassName = fieldType.dec.internalClassFullName() whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { + // check if the the current sub message object + ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = { + code("$lvalue = ${fieldType.dec.internalClassFullName()}()") + }) code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)") } } @@ -391,7 +395,7 @@ class ModelToKotlinCommonGenerator( is FieldType.Map -> TODO() - is FieldType.Reference -> code("encoder.writeMessage(fieldNr = ${field.number}, value = $valueVar.asInternal()) { encodeWith(it) }") + is FieldType.Message -> code("encoder.writeMessage(fieldNr = ${field.number}, value = $valueVar.asInternal()) { encodeWith(it) }") } } @@ -426,12 +430,11 @@ class ModelToKotlinCommonGenerator( modifiers = "private", contextReceiver = declaration.internalClassFullName(), ) { - val requiredFields = declaration.actualFields - .filter { it.dec.isRequired } + val requiredFields = declaration.actualFields.filter { it.dec.isRequired } + val submessages = declaration.actualFields.filter { it.type is FieldType.Message } - if (requiredFields.isEmpty()) { + if (submessages.isEmpty() && requiredFields.isEmpty()) { code("// no fields to check") - return@function } requiredFields.forEach { field -> @@ -439,6 +442,13 @@ class ModelToKotlinCommonGenerator( code("error(\"${declaration.name.simpleName} is missing required field: ${field.name}\")") }) } + + // check submessages + submessages.forEach { field -> + ifBranch(condition = "presenceMask[${field.presenceIdx}]", ifBlock = { + code("${field.name}.asInternal().checkRequiredFields()") + }) + } } private fun CodeGenerator.generateInternalComputeSize(declaration: MessageDeclaration) { @@ -490,9 +500,9 @@ class ModelToKotlinCommonGenerator( else -> code("result = $valueSize") } - is FieldType.Reference, + is FieldType.Message, FieldType.IntegralType.STRING, - FieldType.IntegralType.BYTES -> code("$valueSize.let { $tagSize + ${int32SizeCall("it")} + it }") + FieldType.IntegralType.BYTES -> code("result += $valueSize.let { $tagSize + ${int32SizeCall("it")} + it }") is FieldType.Map -> TODO() is FieldType.OneOf -> whenBlock("val value = $variable") { @@ -545,7 +555,7 @@ class ModelToKotlinCommonGenerator( is FieldType.Enum -> "$PB_PKG.WireSize.$sizeFunName($variable.number)" is FieldType.Map -> TODO() is FieldType.OneOf -> error("OneOf fields have no direct valueSizeCall") - is FieldType.Reference -> "$variable.asInternal()._size" + is FieldType.Message -> "$variable.asInternal()._size" } } @@ -565,7 +575,7 @@ class ModelToKotlinCommonGenerator( } is FieldType.List -> "$name.isNotEmpty()" - is FieldType.Reference -> "" + is FieldType.Message -> "" is FieldType.Enum -> "${fieldType.defaultValue} != $name" @@ -593,7 +603,7 @@ class ModelToKotlinCommonGenerator( is FieldType.Enum -> "Enum" is FieldType.Map -> null is FieldType.OneOf -> null - is FieldType.Reference -> null + is FieldType.Message -> null } private fun FieldDeclaration.transformToFieldDeclaration(): String { @@ -602,7 +612,7 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.typeFqName(): String { return when (type) { - is FieldType.Reference -> { + is FieldType.Message -> { type.dec.name.safeFullName() } @@ -616,7 +626,7 @@ class ModelToKotlinCommonGenerator( is FieldType.List -> { val fqValue = when (val value = type.value) { - is FieldType.Reference -> value.dec.name + is FieldType.Message -> value.dec.name is FieldType.IntegralType -> value.fqName else -> error("Unsupported type: $value") } @@ -628,13 +638,13 @@ class ModelToKotlinCommonGenerator( val entry by type.entry val fqKey = when (val key = entry.key) { - is FieldType.Reference -> key.dec.name + is FieldType.Message -> key.dec.name is FieldType.IntegralType -> key.fqName else -> error("Unsupported type: $key") } val fqValue = when (val value = entry.value) { - is FieldType.Reference -> value.dec.name + is FieldType.Message -> value.dec.name is FieldType.IntegralType -> value.fqName else -> error("Unsupported type: $value") } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index 157c01aff..39300685e 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -246,7 +246,7 @@ private fun Descriptors.FieldDescriptor.modelType(): FieldType { Descriptors.FieldDescriptor.Type.SINT32 -> FieldType.IntegralType.SINT32 Descriptors.FieldDescriptor.Type.SINT64 -> FieldType.IntegralType.SINT64 Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Enum(enumType.toModel()) - Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Reference(messageType!!.toModel()) + Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Message(messageType!!.toModel()) Descriptors.FieldDescriptor.Type.GROUP -> error("GROUP type is unsupported") } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt index 7a33143f0..9fabe0a28 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt @@ -37,7 +37,7 @@ sealed interface FieldType { override val wireType: WireType = WireType.VARINT } - data class Reference(val dec: MessageDeclaration) : FieldType { + data class Message(val dec: MessageDeclaration) : FieldType { override val defaultValue: String? = null override val wireType: WireType = WireType.LENGTH_DELIMITED } From 06d4d547480a08a0b9f193b436a34b77668a814c Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 11:09:59 +0200 Subject: [PATCH 03/11] grpc-pb: Support recursive messages Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/grpc/pb/InternalMessage.kt | 30 ++++++++ .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 51 ++++++++++--- .../src/commonTest/proto/nested.proto | 71 +++++++++++++++++++ .../src/commonTest/proto/recursive.proto | 13 ++++ .../protobuf/ModelToKotlinCommonGenerator.kt | 43 ++++++----- .../rpc/protobuf/codeRequestToModel.kt | 8 ++- .../kotlinx/rpc/protobuf/model/FieldType.kt | 2 +- .../kotlinx/rpc/protobuf/model/model.kt | 1 + 8 files changed, 188 insertions(+), 31 deletions(-) create mode 100644 grpc/grpc-core/src/commonTest/proto/nested.proto create mode 100644 grpc/grpc-core/src/commonTest/proto/recursive.proto diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt index 370f95f58..bbf929c1f 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt @@ -6,9 +6,39 @@ package kotlinx.rpc.grpc.pb import kotlinx.rpc.grpc.utils.BitSet import kotlinx.rpc.internal.utils.InternalRpcApi +import kotlin.properties.ReadWriteProperty +import kotlin.reflect.KProperty @InternalRpcApi public abstract class InternalMessage(fieldsWithPresence: Int) { public val presenceMask: BitSet = BitSet(fieldsWithPresence) public abstract val _size: Int +} + +public class MsgFieldDelegate( + private val presenceIdx: Int? = null, + private val defaultProvider: (() -> T)? = null +) : ReadWriteProperty { + + private var valueSet = false + private var _value: T? = null + + override operator fun getValue(thisRef: InternalMessage, property: KProperty<*>): T { + if (!valueSet) { + if (defaultProvider != null) { + _value = defaultProvider.invoke() + valueSet = true + } else { + error("Property ${property.name} not initialized") + } + } + @Suppress("UNCHECKED_CAST") + return _value as T + } + + override operator fun setValue(thisRef: InternalMessage, property: KProperty<*>, new: T) { + presenceIdx?.let { thisRef.presenceMask[it] = true } + _value = new + valueSet = true + } } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index f857f95e3..a2ea09b23 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -11,11 +11,8 @@ import OuterInternal import invoke import kotlinx.io.Buffer import kotlinx.rpc.grpc.codec.MessageCodec -import kotlinx.rpc.grpc.test.Enum -import kotlinx.rpc.grpc.test.UsingEnum -import kotlinx.rpc.grpc.test.UsingEnumInternal +import kotlinx.rpc.grpc.test.* import kotlinx.rpc.grpc.test.common.* -import kotlinx.rpc.grpc.test.invoke import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -23,7 +20,7 @@ import kotlin.test.assertNull class ProtosTest { - private fun decodeEncode( + private fun encodeDecode( msg: M, codec: MessageCodec ): M { @@ -53,7 +50,7 @@ class ProtosTest { val msgObj = msg - val decoded = decodeEncode(msgObj, AllPrimitivesInternal.CODEC) + val decoded = encodeDecode(msgObj, AllPrimitivesInternal.CODEC) assertEquals(msg.double, decoded.double) } @@ -68,7 +65,7 @@ class ProtosTest { listString = listOf("a", "b", "c") } - val decoded = decodeEncode(msg, RepeatedInternal.CODEC) + val decoded = encodeDecode(msg, RepeatedInternal.CODEC) assertEquals(msg.listInt32, decoded.listInt32) assertEquals(msg.listFixed32, decoded.listFixed32) @@ -112,7 +109,7 @@ class ProtosTest { enum = Enum.ONE_SECOND } - val decodedMsg = decodeEncode(msg, UsingEnumInternal.CODEC) + val decodedMsg = encodeDecode(msg, UsingEnumInternal.CODEC) assertEquals(Enum.ONE, decodedMsg.enum) assertEquals(Enum.ONE_SECOND, decodedMsg.enum) } @@ -135,13 +132,13 @@ class ProtosTest { val msg1 = OneOfMsg { field = OneOfMsg.Field.Sint(23) } - val decoded1 = decodeEncode(msg1, OneOfMsgInternal.CODEC) + val decoded1 = encodeDecode(msg1, OneOfMsgInternal.CODEC) assertEquals(OneOfMsg.Field.Sint(23), decoded1.field) val msg2 = OneOfMsg { field = OneOfMsg.Field.Fixed(21u) } - val decoded2 = decodeEncode(msg2, OneOfMsgInternal.CODEC) + val decoded2 = encodeDecode(msg2, OneOfMsgInternal.CODEC) assertEquals(OneOfMsg.Field.Fixed(21u), decoded2.field) } @@ -175,8 +172,40 @@ class ProtosTest { field = 12345678 } } - val decoded = decodeEncode(msg, OuterInternal.CODEC) + val decoded = encodeDecode(msg, OuterInternal.CODEC) assertEquals(msg.inner.field, decoded.inner.field) } + @Test + fun testRecursiveReqNotSet() { + assertFailsWith("RecursiveReq is missing required field: rec") { + val msg = RecursiveReq { + rec = RecursiveReq { + rec = RecursiveReq { + + } + num = 3 + } + } + } + } + + @Test + fun testRecursive() { + val msg = Recursive { + rec = Recursive { + rec = Recursive {} + num = 3 + } + } + + assertEquals(null, msg.rec.rec.rec.rec.num) + assertEquals(3, msg.rec.num) + + val decoded = encodeDecode(msg, RecursiveInternal.CODEC) + + assertEquals(3, decoded.rec.num) + assertEquals(null, decoded.rec.rec.rec.rec.num) + } + } diff --git a/grpc/grpc-core/src/commonTest/proto/nested.proto b/grpc/grpc-core/src/commonTest/proto/nested.proto new file mode 100644 index 000000000..a09ef69f0 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/nested.proto @@ -0,0 +1,71 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2023 Google LLC. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +syntax = "proto2"; + +package kotlinx.rpc.grpc.test; + +message Outer { + message Inner { + message InnerSubMsg { + optional bool flag = 1; + } + + enum InnerEnum { + INNER_ENUM_UNSPECIFIED = 0; + INNER_ENUM_FOO = 1; + } + + optional double double = 1; + optional float float = 2; + optional int32 int32 = 3; + optional int64 int64 = 4; + optional uint32 uint32 = 5; + optional uint64 uint64 = 6; + optional sint32 sint32 = 7; + optional sint64 sint64 = 8; + optional fixed32 fixed32 = 9; + optional fixed64 fixed64 = 10; + optional sfixed32 sfixed32 = 11; + optional sfixed64 sfixed64 = 12; + optional bool bool = 13; + optional string string = 14; + optional bytes bytes = 15; + optional InnerSubMsg inner_submsg = 16; + optional InnerEnum inner_enum = 17; + repeated int32 repeated_int32 = 18 [packed = true]; + repeated InnerSubMsg repeated_inner_submsg = 19; + // map string_map = 20; + + message SuperInner { + message DuperInner { + message EvenMoreInner { + message CantBelieveItsSoInner { + optional int32 num = 99; + } + + enum JustWayTooInner { + JUST_WAY_TOO_INNER_UNSPECIFIED = 0; + } + } + } + } + } + // optional Inner inner = 1; + // optional .kotlinx.rpc.grpc.test.Outer.Inner.SuperInner.DuperInner.EvenMoreInner + // .CantBelieveItsSoInner deep = 2; + // + // optional .kotlinx.rpc.grpc.test.Outer.Inner.SuperInner.DuperInner.EvenMoreInner.JustWayTooInner + // deep_enum = 4; + + optional NotInside notinside = 3; +} + +message NotInside { + optional int32 num = 1; +} + diff --git a/grpc/grpc-core/src/commonTest/proto/recursive.proto b/grpc/grpc-core/src/commonTest/proto/recursive.proto new file mode 100644 index 000000000..af40e678b --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/recursive.proto @@ -0,0 +1,13 @@ +syntax = "proto2"; + +package kotlinx.rpc.grpc.test; + +message RecursiveReq { + required RecursiveReq rec = 1; + optional int32 num = 2; +} + +message Recursive { + optional Recursive rec = 1; + optional int32 num = 2; +} diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 9d1d24f67..22193034c 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -182,20 +182,16 @@ class ModelToKotlinCommonGenerator( "= null" } - field.type is FieldType.Message -> "= ${field.type.dec.internalClassFullName()}()" + field.type is FieldType.Message -> + "by MsgFieldDelegate(PresenceIndices.${field.name}) { ${field.type.dec.value.internalClassFullName()}() }" else -> { - "= ${field.type.defaultValue}" + val fieldPresence = if (field.presenceIdx != null) "PresenceIndices.${field.name}" else "" + "by MsgFieldDelegate($fieldPresence) { ${field.type.defaultValue} }" } } code("override var $fieldDeclaration $value") - if (field.presenceIdx != null) { - scope("set(value) ") { - code("presenceMask[PresenceIndices.${field.name}] = true") - code("field = value") - } - } newLine() } @@ -320,11 +316,11 @@ class ModelToKotlinCommonGenerator( } is FieldType.Message -> { - val internalClassName = fieldType.dec.internalClassFullName() + val internalClassName = fieldType.dec.value.internalClassFullName() whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { // check if the the current sub message object ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = { - code("$lvalue = ${fieldType.dec.internalClassFullName()}()") + code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()") }) code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)") } @@ -356,7 +352,9 @@ class ModelToKotlinCommonGenerator( writeFieldValue(field, field.name) }) } else { - writeFieldValue(field, field.name) + ifBranch(condition = "presenceMask[${field.presenceIdx}]", ifBlock = { + writeFieldValue(field, field.name) + }) } } } @@ -379,7 +377,16 @@ class ModelToKotlinCommonGenerator( })" ) - else -> code("$valueVar.forEach { encoder.write${encFunc!!}($number, it) }") + fieldType.value is FieldType.Message -> scope("$valueVar.forEach") { + code("encoder.writeMessage(fieldNr = ${field.number}, value = it.asInternal()) { encodeWith(it) }") + } + + else -> { + require(encFunc != null) { "No encode function for list type: $fieldType" } + scope("$valueVar.forEach") { + code("encoder.write${encFunc}($number, it)") + } + } } } @@ -470,7 +477,9 @@ class ModelToKotlinCommonGenerator( generateFieldComputeSizeCall(field, fieldName) } } else { - generateFieldComputeSizeCall(field, fieldName) + scope("if (presenceMask[${field.presenceIdx}])") { + generateFieldComputeSizeCall(field, fieldName) + } } } code("return result") @@ -613,7 +622,7 @@ class ModelToKotlinCommonGenerator( private fun FieldDeclaration.typeFqName(): String { return when (type) { is FieldType.Message -> { - type.dec.name.safeFullName() + type.dec.value.name.safeFullName() } is FieldType.Enum -> type.dec.name.safeFullName() @@ -626,7 +635,7 @@ class ModelToKotlinCommonGenerator( is FieldType.List -> { val fqValue = when (val value = type.value) { - is FieldType.Message -> value.dec.name + is FieldType.Message -> value.dec.value.name is FieldType.IntegralType -> value.fqName else -> error("Unsupported type: $value") } @@ -638,13 +647,13 @@ class ModelToKotlinCommonGenerator( val entry by type.entry val fqKey = when (val key = entry.key) { - is FieldType.Message -> key.dec.name + is FieldType.Message -> key.dec.value.name is FieldType.IntegralType -> key.fqName else -> error("Unsupported type: $key") } val fqValue = when (val value = entry.value) { - is FieldType.Message -> value.dec.name + is FieldType.Message -> value.dec.value.name is FieldType.IntegralType -> value.fqName else -> error("Unsupported type: $value") } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index 39300685e..d9f013944 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -9,6 +9,7 @@ import com.google.protobuf.Descriptors import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest import kotlinx.rpc.protobuf.model.* +private val nameCache = mutableMapOf() private val modelCache = mutableMapOf() /** @@ -65,9 +66,10 @@ private fun DescriptorProtos.FileDescriptorProto.toDescriptor( * @return The fully qualified name represented as an instance of FqName, specific to the descriptor's context. */ private fun Descriptors.GenericDescriptor.fqName(): FqName { + if (nameCache.containsKey(this)) return nameCache[this]!! val nameCapital = name.simpleProtoNameToKotlin(firstLetterUpper = true) val nameLower = name.simpleProtoNameToKotlin() - return when (this) { + val fqName = when (this) { is Descriptors.FileDescriptor -> FqName.Package.fromString(`package`) is Descriptors.Descriptor -> FqName.Declaration(nameCapital, containingType?.fqName() ?: file.fqName()) is Descriptors.FieldDescriptor -> { @@ -82,6 +84,8 @@ private fun Descriptors.GenericDescriptor.fqName(): FqName { is Descriptors.MethodDescriptor -> FqName.Declaration(nameLower, service?.fqName() ?: file.fqName()) else -> error("Unknown generic descriptor: $this") } + nameCache[this] = fqName + return fqName } /** @@ -246,7 +250,7 @@ private fun Descriptors.FieldDescriptor.modelType(): FieldType { Descriptors.FieldDescriptor.Type.SINT32 -> FieldType.IntegralType.SINT32 Descriptors.FieldDescriptor.Type.SINT64 -> FieldType.IntegralType.SINT64 Descriptors.FieldDescriptor.Type.ENUM -> FieldType.Enum(enumType.toModel()) - Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Message(messageType!!.toModel()) + Descriptors.FieldDescriptor.Type.MESSAGE -> FieldType.Message(lazy { messageType!!.toModel() }) Descriptors.FieldDescriptor.Type.GROUP -> error("GROUP type is unsupported") } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt index 9fabe0a28..d1eda9206 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt @@ -37,7 +37,7 @@ sealed interface FieldType { override val wireType: WireType = WireType.VARINT } - data class Message(val dec: MessageDeclaration) : FieldType { + data class Message(val dec: Lazy) : FieldType { override val defaultValue: String? = null override val wireType: WireType = WireType.LENGTH_DELIMITED } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index 275f2c104..cac6e4da2 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -79,6 +79,7 @@ data class FieldDeclaration( val nullable: Boolean = (dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() && !dec.isRepeated // repeated fields cannot be nullable (just empty) && dec.realContainingOneof == null // upper conditions would match oneof inner fields + && type !is FieldType.Message // messages must not be null (to conform protobuf standards) ) || type is FieldType.OneOf // all OneOf fields are nullable val number: Int = dec.number From 2659375a059af8d005b0bd74628c5aa1bae63e3d Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 12:02:00 +0200 Subject: [PATCH 04/11] grpc-pb: Support nested message Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/grpc/pb/InternalMessage.kt | 7 +- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 61 ++++++++++++++-- .../grpc-core/src/commonTest/proto/enum.proto | 4 +- .../src/commonTest/proto/nested.proto | 73 ++++++++++--------- .../src/commonTest/proto/recursive.proto | 2 +- .../protobuf/ModelToKotlinCommonGenerator.kt | 25 +++---- 6 files changed, 112 insertions(+), 60 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt index bbf929c1f..32383c2d2 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt @@ -12,6 +12,8 @@ import kotlin.reflect.KProperty @InternalRpcApi public abstract class InternalMessage(fieldsWithPresence: Int) { public val presenceMask: BitSet = BitSet(fieldsWithPresence) + + @Suppress("PropertyName") public abstract val _size: Int } @@ -32,13 +34,12 @@ public class MsgFieldDelegate( error("Property ${property.name} not initialized") } } - @Suppress("UNCHECKED_CAST") return _value as T } - override operator fun setValue(thisRef: InternalMessage, property: KProperty<*>, new: T) { + override operator fun setValue(thisRef: InternalMessage, property: KProperty<*>, value: T) { presenceIdx?.let { thisRef.presenceMask[it] = true } - _value = new + _value = value valueSet = true } } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index a2ea09b23..95c07cdf1 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -11,8 +11,16 @@ import OuterInternal import invoke import kotlinx.io.Buffer import kotlinx.rpc.grpc.codec.MessageCodec -import kotlinx.rpc.grpc.test.* +import kotlinx.rpc.grpc.test.MyEnum +import kotlinx.rpc.grpc.test.UsingEnum +import kotlinx.rpc.grpc.test.UsingEnumInternal import kotlinx.rpc.grpc.test.common.* +import kotlinx.rpc.grpc.test.invoke +import test.nested.* +import test.recursive.Recursive +import test.recursive.RecursiveInternal +import test.recursive.RecursiveReq +import test.recursive.invoke import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -100,18 +108,18 @@ class ProtosTest { encoder.flush() val decodedMsg = UsingEnumInternal.CODEC.decode(buffer) - assertEquals(Enum.UNRECOGNIZED(50), decodedMsg.enum) + assertEquals(MyEnum.UNRECOGNIZED(50), decodedMsg.enum) } @Test fun testEnumAlias() { val msg = UsingEnum { - enum = Enum.ONE_SECOND + enum = MyEnum.ONE_SECOND } val decodedMsg = encodeDecode(msg, UsingEnumInternal.CODEC) - assertEquals(Enum.ONE, decodedMsg.enum) - assertEquals(Enum.ONE_SECOND, decodedMsg.enum) + assertEquals(MyEnum.ONE, decodedMsg.enum) + assertEquals(MyEnum.ONE_SECOND, decodedMsg.enum) } @Test @@ -124,7 +132,7 @@ class ProtosTest { assertEquals(0, buffer.size) val decoded = UsingEnumInternal.CODEC.decode(buffer) - assertEquals(Enum.ZERO, decoded.enum) + assertEquals(MyEnum.ZERO, decoded.enum) } @Test @@ -208,4 +216,45 @@ class ProtosTest { assertEquals(null, decoded.rec.rec.rec.rec.num) } + @Test + fun testNested() { + val inner = NestedOuter.Inner.SuperInner.DuperInner.EvenMoreInner.CantBelieveItsSoInner { + num = 123456789 + } + + val notInside = NotInside { + num = -12 + } + val outer = NestedOuter { + deep = inner + deepEnum = + NestedOuter.Inner.SuperInner.DuperInner.EvenMoreInner.JustWayTooInner.JUST_WAY_TOO_INNER_UNSPECIFIED + } + + assertEquals(123456789, outer.deep.num) + assertEquals( + NestedOuter.Inner.SuperInner.DuperInner.EvenMoreInner.JustWayTooInner.JUST_WAY_TOO_INNER_UNSPECIFIED, + outer.deepEnum + ) + assertEquals(-12, notInside.num) + + val decodedOuter = encodeDecode(outer, NestedOuterInternal.CODEC) + assertEquals(123456789, decodedOuter.deep.num) + assertEquals( + NestedOuter.Inner.SuperInner.DuperInner.EvenMoreInner.JustWayTooInner.JUST_WAY_TOO_INNER_UNSPECIFIED, + decodedOuter.deepEnum + ) + assertEquals(-12, notInside.num) + + val decodedNotInside = encodeDecode(notInside, NotInsideInternal.CODEC) + assertEquals(-12, decodedNotInside.num) + + val decodedInner = encodeDecode( + inner, + NestedOuterInternal.InnerInternal.SuperInnerInternal.DuperInnerInternal.EvenMoreInnerInternal.CantBelieveItsSoInnerInternal.CODEC + ) + assertEquals(123456789, decodedInner.num) + + } + } diff --git a/grpc/grpc-core/src/commonTest/proto/enum.proto b/grpc/grpc-core/src/commonTest/proto/enum.proto index 83db7b7fe..2a3a8e716 100644 --- a/grpc/grpc-core/src/commonTest/proto/enum.proto +++ b/grpc/grpc-core/src/commonTest/proto/enum.proto @@ -2,7 +2,7 @@ syntax = "proto3"; package kotlinx.rpc.grpc.test; -enum Enum { +enum MyEnum { option allow_alias = true; ZERO = 0; ONE = 1; @@ -12,5 +12,5 @@ enum Enum { } message UsingEnum { - Enum enum = 1; + MyEnum enum = 1; } diff --git a/grpc/grpc-core/src/commonTest/proto/nested.proto b/grpc/grpc-core/src/commonTest/proto/nested.proto index a09ef69f0..eacefa903 100644 --- a/grpc/grpc-core/src/commonTest/proto/nested.proto +++ b/grpc/grpc-core/src/commonTest/proto/nested.proto @@ -7,38 +7,38 @@ syntax = "proto2"; -package kotlinx.rpc.grpc.test; +package test.nested; -message Outer { +message NestedOuter { message Inner { - message InnerSubMsg { - optional bool flag = 1; - } - - enum InnerEnum { - INNER_ENUM_UNSPECIFIED = 0; - INNER_ENUM_FOO = 1; - } + // message InnerSubMsg { + // optional bool flag = 1; + // } + // + // enum InnerEnum { + // INNER_ENUM_UNSPECIFIED = 0; + // INNER_ENUM_FOO = 1; + // } - optional double double = 1; - optional float float = 2; - optional int32 int32 = 3; - optional int64 int64 = 4; - optional uint32 uint32 = 5; - optional uint64 uint64 = 6; - optional sint32 sint32 = 7; - optional sint64 sint64 = 8; - optional fixed32 fixed32 = 9; - optional fixed64 fixed64 = 10; - optional sfixed32 sfixed32 = 11; - optional sfixed64 sfixed64 = 12; - optional bool bool = 13; - optional string string = 14; - optional bytes bytes = 15; - optional InnerSubMsg inner_submsg = 16; - optional InnerEnum inner_enum = 17; - repeated int32 repeated_int32 = 18 [packed = true]; - repeated InnerSubMsg repeated_inner_submsg = 19; + // optional double double = 1; + // optional float float = 2; + // optional int32 int32 = 3; + // optional int64 int64 = 4; + // optional uint32 uint32 = 5; + // optional uint64 uint64 = 6; + // optional sint32 sint32 = 7; + // optional sint64 sint64 = 8; + // optional fixed32 fixed32 = 9; + // optional fixed64 fixed64 = 10; + // optional sfixed32 sfixed32 = 11; + // optional sfixed64 sfixed64 = 12; + // optional bool bool = 13; + // optional string string = 14; + // optional bytes bytes = 15; + // optional InnerSubMsg inner_submsg = 16; + // optional InnerEnum inner_enum = 17; + // repeated int32 repeated_int32 = 18 [packed = true]; + // repeated InnerSubMsg repeated_inner_submsg = 19; // map string_map = 20; message SuperInner { @@ -55,16 +55,19 @@ message Outer { } } } - // optional Inner inner = 1; - // optional .kotlinx.rpc.grpc.test.Outer.Inner.SuperInner.DuperInner.EvenMoreInner - // .CantBelieveItsSoInner deep = 2; - // - // optional .kotlinx.rpc.grpc.test.Outer.Inner.SuperInner.DuperInner.EvenMoreInner.JustWayTooInner - // deep_enum = 4; + optional Inner inner = 1; + optional .test.nested.NestedOuter.Inner.SuperInner.DuperInner.EvenMoreInner + .CantBelieveItsSoInner deep = 2; + optional .test.nested.NestedOuter.Inner.SuperInner.DuperInner.EvenMoreInner.JustWayTooInner + deep_enum = 4; optional NotInside notinside = 3; } +enum OutsideEnum { + JUST_WAY_TOO_INNER_UNSPECIFIED = 0; +} + message NotInside { optional int32 num = 1; } diff --git a/grpc/grpc-core/src/commonTest/proto/recursive.proto b/grpc/grpc-core/src/commonTest/proto/recursive.proto index af40e678b..79103ca4f 100644 --- a/grpc/grpc-core/src/commonTest/proto/recursive.proto +++ b/grpc/grpc-core/src/commonTest/proto/recursive.proto @@ -1,6 +1,6 @@ syntax = "proto2"; -package kotlinx.rpc.grpc.test; +package test.recursive; message RecursiveReq { required RecursiveReq rec = 1; diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 22193034c..b1a0660b7 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -112,7 +112,7 @@ class ModelToKotlinCommonGenerator( messages.forEach { generateInternalMessage(it) } // emit all required functions in the outer scope - val allMsgs = messages + messages.flatMap(MessageDeclaration::nestedDeclarations) + val allMsgs = messages + messages.flatMap(MessageDeclaration::allNestedRecursively) allMsgs.forEach { generateMessageConstructor(it) } @@ -713,21 +713,20 @@ class ModelToKotlinCommonGenerator( superTypes = listOf("$className(number)"), ) - if (declaration.aliases.isNotEmpty()) { - newLine() - - clazz("", modifiers = "companion", declarationType = DeclarationType.Object) { - declaration.aliases.forEach { alias: EnumDeclaration.Alias -> - code( - "val ${alias.name.simpleName}: $className " + - "get() = ${alias.original.name.simpleName}" - ) - } + newLine() - val entryNamesSorted = entriesSorted.joinToString(", ") { it.name.simpleName } - code("val entries: List<$className> by lazy { listOf($entryNamesSorted) }") + clazz("", modifiers = "companion", declarationType = DeclarationType.Object) { + declaration.aliases.forEach { alias: EnumDeclaration.Alias -> + code( + "val ${alias.name.simpleName}: $className " + + "get() = ${alias.original.name.simpleName}" + ) } + + val entryNamesSorted = entriesSorted.joinToString(", ") { it.name.simpleName } + code("val entries: List<$className> by lazy { listOf($entryNamesSorted) }") } + } } From 987f15958377aa20563a40f1e04553c28c899a4d Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 14:39:58 +0200 Subject: [PATCH 05/11] grpc-pb: Support repeated messages Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 32 +++++++ .../commonTest/proto/exclude/reference.proto | 9 -- .../src/commonTest/proto/nested.proto | 54 ++++++------ .../src/commonTest/proto/repeated.proto | 5 ++ .../src/commonTest/proto/sub_message.proto | 13 +++ .../protobuf/ModelToKotlinCommonGenerator.kt | 88 ++++++++++++++----- 6 files changed, 143 insertions(+), 58 deletions(-) delete mode 100644 grpc/grpc-core/src/commonTest/proto/exclude/reference.proto create mode 100644 grpc/grpc-core/src/commonTest/proto/sub_message.proto diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 95c07cdf1..5a5cc4ba4 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -21,6 +21,7 @@ import test.recursive.Recursive import test.recursive.RecursiveInternal import test.recursive.RecursiveReq import test.recursive.invoke +import test.submsg.* import kotlin.test.Test import kotlin.test.assertEquals import kotlin.test.assertFailsWith @@ -65,12 +66,14 @@ class ProtosTest { @Test fun testRepeatedProto() { + val elem = { i: Int -> Repeated.Other { a = i } } val msg = Repeated { listFixed32 = listOf(1, 5, 3).map { it.toUInt() } listFixed32Packed = listOf(1, 2, 3).map { it.toUInt() } listInt32 = listOf(4, 7, 6) listInt32Packed = listOf(4, 5, 6) listString = listOf("a", "b", "c") + listMessage = listOf(elem(1), elem(2), elem(3)) } val decoded = encodeDecode(msg, RepeatedInternal.CODEC) @@ -78,6 +81,10 @@ class ProtosTest { assertEquals(msg.listInt32, decoded.listInt32) assertEquals(msg.listFixed32, decoded.listFixed32) assertEquals(msg.listString, decoded.listString) + assertEquals(msg.listMessage.size, decoded.listMessage.size) + for (i in msg.listMessage.indices) { + assertEquals(msg.listMessage[i].a, decoded.listMessage[i].a) + } } @Test @@ -254,7 +261,32 @@ class ProtosTest { NestedOuterInternal.InnerInternal.SuperInnerInternal.DuperInnerInternal.EvenMoreInnerInternal.CantBelieveItsSoInnerInternal.CODEC ) assertEquals(123456789, decodedInner.num) + } + + @Test + fun testMessageMerging() { + + val buffer = Buffer() + val encoder = WireEncoder(buffer) + + val firstPart = Other { + arg1 = "first" + arg2 = "second" + } + val secondPart = Other { + arg2 = "third" + arg3 = "fourth" + } + + encoder.writeMessage(1, firstPart as OtherInternal) { encodeWith(encoder) } + encoder.flush() + encoder.writeMessage(1, secondPart as OtherInternal) { encodeWith(encoder) } + encoder.flush() + val decoded = ReferenceInternal.CODEC.decode(buffer) + assertEquals("first", decoded.other.arg1) + assertEquals("third", decoded.other.arg2) + assertEquals("fourth", decoded.other.arg3) } } diff --git a/grpc/grpc-core/src/commonTest/proto/exclude/reference.proto b/grpc/grpc-core/src/commonTest/proto/exclude/reference.proto deleted file mode 100644 index 4a68c5189..000000000 --- a/grpc/grpc-core/src/commonTest/proto/exclude/reference.proto +++ /dev/null @@ -1,9 +0,0 @@ -syntax = "proto3"; - -message Other { - string arg = 1; -} - -message References { - Other other = 2; -} diff --git a/grpc/grpc-core/src/commonTest/proto/nested.proto b/grpc/grpc-core/src/commonTest/proto/nested.proto index eacefa903..18e0bf6cd 100644 --- a/grpc/grpc-core/src/commonTest/proto/nested.proto +++ b/grpc/grpc-core/src/commonTest/proto/nested.proto @@ -11,34 +11,34 @@ package test.nested; message NestedOuter { message Inner { - // message InnerSubMsg { - // optional bool flag = 1; - // } - // - // enum InnerEnum { - // INNER_ENUM_UNSPECIFIED = 0; - // INNER_ENUM_FOO = 1; - // } + message InnerSubMsg { + optional bool flag = 1; + } + + enum InnerEnum { + INNER_ENUM_UNSPECIFIED = 0; + INNER_ENUM_FOO = 1; + } - // optional double double = 1; - // optional float float = 2; - // optional int32 int32 = 3; - // optional int64 int64 = 4; - // optional uint32 uint32 = 5; - // optional uint64 uint64 = 6; - // optional sint32 sint32 = 7; - // optional sint64 sint64 = 8; - // optional fixed32 fixed32 = 9; - // optional fixed64 fixed64 = 10; - // optional sfixed32 sfixed32 = 11; - // optional sfixed64 sfixed64 = 12; - // optional bool bool = 13; - // optional string string = 14; - // optional bytes bytes = 15; - // optional InnerSubMsg inner_submsg = 16; - // optional InnerEnum inner_enum = 17; - // repeated int32 repeated_int32 = 18 [packed = true]; - // repeated InnerSubMsg repeated_inner_submsg = 19; + optional double double = 1; + optional float float = 2; + optional int32 int32 = 3; + optional int64 int64 = 4; + optional uint32 uint32 = 5; + optional uint64 uint64 = 6; + optional sint32 sint32 = 7; + optional sint64 sint64 = 8; + optional fixed32 fixed32 = 9; + optional fixed64 fixed64 = 10; + optional sfixed32 sfixed32 = 11; + optional sfixed64 sfixed64 = 12; + optional bool bool = 13; + optional string string = 14; + optional bytes bytes = 15; + optional InnerSubMsg inner_submsg = 16; + optional InnerEnum inner_enum = 17; + repeated int32 repeated_int32 = 18 [packed = true]; + repeated InnerSubMsg repeated_inner_submsg = 19; // map string_map = 20; message SuperInner { diff --git a/grpc/grpc-core/src/commonTest/proto/repeated.proto b/grpc/grpc-core/src/commonTest/proto/repeated.proto index e80c7b445..fad5caf75 100644 --- a/grpc/grpc-core/src/commonTest/proto/repeated.proto +++ b/grpc/grpc-core/src/commonTest/proto/repeated.proto @@ -8,4 +8,9 @@ message Repeated { repeated int32 listInt32 = 3 [packed = false]; repeated int32 listInt32Packed = 4 [packed = true]; repeated string listString = 5; + repeated Other listMessage = 6; + + message Other { + int32 a = 1; + } } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/sub_message.proto b/grpc/grpc-core/src/commonTest/proto/sub_message.proto new file mode 100644 index 000000000..f894dffb2 --- /dev/null +++ b/grpc/grpc-core/src/commonTest/proto/sub_message.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package test.submsg; + +message Other { + optional string arg1 = 1; + optional string arg2 = 2; + optional string arg3 = 3; +} + +message Reference { + Other other = 1; +} diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index b1a0660b7..4c735fd7b 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -284,24 +284,21 @@ class ModelToKotlinCommonGenerator( ) { when (val fieldType = field.type) { is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") { - val raw = "decoder.read${field.type.decodeEncodeFuncName()}()" - code("$lvalue = ${wrapperCtor(raw)}") + generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor) } is FieldType.List -> if (field.dec.isPacked) { whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { - code("$lvalue = decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()") + generateDecodeFieldValue(fieldType, lvalue, isPacked = true, wrapperCtor = wrapperCtor) } } else { whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${fieldType.value.wireType.name}") { - code("(msg.${field.name} as ArrayList).add(decoder.read${fieldType.value.decodeEncodeFuncName()}())") + generateDecodeFieldValue(fieldType, lvalue, isPacked = false, wrapperCtor = wrapperCtor) } } is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") { - val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber" - val raw = "$fromNum(decoder.read${field.type.decodeEncodeFuncName()}())" - code("$lvalue = ${wrapperCtor(raw)}") + generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor) } is FieldType.OneOf -> { @@ -316,13 +313,12 @@ class ModelToKotlinCommonGenerator( } is FieldType.Message -> { - val internalClassName = fieldType.dec.value.internalClassFullName() whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { // check if the the current sub message object ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = { code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()") }) - code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)") + generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor) } } @@ -330,6 +326,58 @@ class ModelToKotlinCommonGenerator( } } + private fun CodeGenerator.generateDecodeFieldValue( + fieldType: FieldType, + lvalue: String, + isPacked: Boolean = false, + wrapperCtor: (String) -> String = { it } + ) { + when (fieldType) { + is FieldType.IntegralType -> { + val raw = "decoder.read${fieldType.decodeEncodeFuncName()}()" + code("$lvalue = ${wrapperCtor(raw)}") + } + + is FieldType.List -> if (isPacked) { + code("$lvalue = decoder.readPacked${fieldType.value.decodeEncodeFuncName()}()") + } else { + when (val elemType = fieldType.value) { + is FieldType.Message -> { + code("val elem = ${elemType.dec.value.internalClassFullName()}()") + generateDecodeFieldValue(fieldType.value, "elem", wrapperCtor = wrapperCtor) + } + + else -> generateDecodeFieldValue(fieldType.value, "val elem", wrapperCtor = wrapperCtor) + } + code("($lvalue as ArrayList).add(elem)") + } + + is FieldType.Enum -> { + val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber" + val raw = "$fromNum(decoder.read${fieldType.decodeEncodeFuncName()}())" + code("$lvalue = ${wrapperCtor(raw)}") + } + + is FieldType.OneOf -> { + fieldType.dec.variants.forEach { variant -> + val variantName = "${fieldType.dec.name.safeFullName()}.${variant.name}" + readMatchCase( + field = variant, + lvalue = lvalue, + wrapperCtor = { "$variantName($it)" } + ) + } + } + + is FieldType.Message -> { + val internalClassName = fieldType.dec.value.internalClassFullName() + code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)") + } + + is FieldType.Map -> TODO() + } + } + private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function( name = "encodeWith", modifiers = "internal", @@ -373,7 +421,7 @@ class ModelToKotlinCommonGenerator( field.dec.isPacked && !field.packedFixedSize -> code( "encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar, fieldSize = ${ - field.valueSizeCall(valueVar) + field.type.valueSizeCall(valueVar, number, true) })" ) @@ -499,7 +547,7 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateFieldComputeSizeCall(field: FieldDeclaration, variable: String) { - val valueSize by lazy { field.valueSizeCall(variable) } + val valueSize by lazy { field.type.valueSizeCall(variable, field.number, field.dec.isPacked) } val tagSize = tagSizeCall(field.number, field.type.wireType) when (field.type) { @@ -540,24 +588,20 @@ class ModelToKotlinCommonGenerator( } } - private fun FieldDeclaration.valueSizeCall(variable: String): String { - val sizeFunName = type.decodeEncodeFuncName()?.decapitalize() + private fun FieldType.valueSizeCall(variable: String, number: Int, isPacked: Boolean = false): String { + val sizeFunName = decodeEncodeFuncName()?.decapitalize() val sizeFunc = "$PB_PKG.WireSize.$sizeFunName($variable)" - return when (type) { + return when (this) { is FieldType.IntegralType -> sizeFunc is FieldType.List -> when { - dec.isPacked -> sizeFunc + isPacked -> sizeFunc else -> { // calculate the size of the values within the list. - val valueTypeSizeFunc = type.value.decodeEncodeFuncName()?.decapitalize() - "$variable.sumOf { $PB_PKG.WireSize.$valueTypeSizeFunc(it) + ${ - tagSizeCall( - number, - type.value.wireType - ) - } }" + val valueSize = value.valueSizeCall("it", number) + val tagSize = tagSizeCall(number, value.wireType) + "$variable.sumOf { $valueSize + $tagSize }" } } From 2456f799badf65eb438e761e1c2c04cb65a43a65 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 14:47:07 +0200 Subject: [PATCH 06/11] grpc-pb: Make generated methods internal Signed-off-by: Johannes Zottele --- .../main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt | 4 ++++ .../rpc/protobuf/ModelToKotlinCommonGenerator.kt | 11 ++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt index ebbadd12f..0dd8db665 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt @@ -177,9 +177,13 @@ open class CodeGenerator( typeParameters: String = "", args: String = "", contextReceiver: String = "", + annotations: List = emptyList(), returnType: String = "", block: (CodeGenerator.() -> Unit)? = null, ) { + for (annotation in annotations) { + addLine(annotation) + } val modifiersString = if (modifiers.isEmpty()) "" else "$modifiers " val contextString = if (contextReceiver.isEmpty()) "" else "$contextReceiver." val returnTypeString = if (returnType.isEmpty()) "" else ": $returnType" diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 4c735fd7b..f54584537 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -260,6 +260,7 @@ class ModelToKotlinCommonGenerator( name = "decodeWith", modifiers = "internal", args = "msg: ${declaration.internalClassFullName()}, decoder: $PB_PKG.WireDecoder", + annotations = listOf("@$INTERNAL_RPC_API_ANNO"), contextReceiver = "${declaration.internalClassFullName()}.CODEC" ) { whileBlock("!decoder.hadError()") { @@ -381,6 +382,7 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateMessageEncoder(declaration: MessageDeclaration) = function( name = "encodeWith", modifiers = "internal", + annotations = listOf("@$INTERNAL_RPC_API_ANNO"), args = "encoder: $PB_PKG.WireEncoder", contextReceiver = declaration.internalClassFullName(), ) { @@ -459,8 +461,9 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateInternalEnumConstructor(enum: EnumDeclaration) { function( "fromNumber", - modifiers = "private", + modifiers = "internal", args = "number: Int", + annotations = listOf("@$INTERNAL_RPC_API_ANNO"), contextReceiver = "${enum.name.safeFullName()}.Companion", returnType = enum.name.safeFullName(), ) { @@ -482,7 +485,8 @@ class ModelToKotlinCommonGenerator( */ private fun CodeGenerator.generateRequiredCheck(declaration: MessageDeclaration) = function( name = "checkRequiredFields", - modifiers = "private", + modifiers = "internal", + annotations = listOf("@$INTERNAL_RPC_API_ANNO"), contextReceiver = declaration.internalClassFullName(), ) { val requiredFields = declaration.actualFields.filter { it.dec.isRequired } @@ -537,7 +541,8 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateInternalCastExtension(declaration: MessageDeclaration) { function( "asInternal", - modifiers = "private", + modifiers = "internal", + annotations = listOf("@$INTERNAL_RPC_API_ANNO"), contextReceiver = declaration.name.safeFullName(), returnType = declaration.internalClassFullName(), ) { From b6cb8d43bd7eba2f8e3cf22c0034535e5701a2d1 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 15:38:13 +0200 Subject: [PATCH 07/11] grpc-pb: Support message in oneof Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 14 ++++-- .../src/commonTest/proto/oneof.proto | 3 ++ .../protobuf/ModelToKotlinCommonGenerator.kt | 47 +++++++++++++++---- .../kotlinx/rpc/protobuf/model/model.kt | 4 +- 4 files changed, 53 insertions(+), 15 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 5a5cc4ba4..6e233ec8f 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -22,10 +22,7 @@ import test.recursive.RecursiveInternal import test.recursive.RecursiveReq import test.recursive.invoke import test.submsg.* -import kotlin.test.Test -import kotlin.test.assertEquals -import kotlin.test.assertFailsWith -import kotlin.test.assertNull +import kotlin.test.* class ProtosTest { @@ -155,6 +152,15 @@ class ProtosTest { } val decoded2 = encodeDecode(msg2, OneOfMsgInternal.CODEC) assertEquals(OneOfMsg.Field.Fixed(21u), decoded2.field) + + val msg3 = OneOfMsg { + field = OneOfMsg.Field.Other(Other { arg2 = "test" }) + } + val decoded3 = encodeDecode(msg3, OneOfMsgInternal.CODEC) + assertIs(decoded3.field) + assertNull((decoded3.field as OneOfMsg.Field.Other).value.arg1) + assertEquals("test", (decoded3.field as OneOfMsg.Field.Other).value.arg2) + assertNull((decoded3.field as OneOfMsg.Field.Other).value.arg3) } @Test diff --git a/grpc/grpc-core/src/commonTest/proto/oneof.proto b/grpc/grpc-core/src/commonTest/proto/oneof.proto index bde4b3883..cc4c3d217 100644 --- a/grpc/grpc-core/src/commonTest/proto/oneof.proto +++ b/grpc/grpc-core/src/commonTest/proto/oneof.proto @@ -1,6 +1,9 @@ +import "sub_message.proto"; + message OneOfMsg { oneof field { int32 sint = 2; fixed64 fixed = 3; + test.submsg.Other other = 4; } } \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index f54584537..23ef8933f 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -281,44 +281,68 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.readMatchCase( field: FieldDeclaration, lvalue: String = "msg.${field.name}", - wrapperCtor: (String) -> String = { it } + wrapperCtor: (String) -> String = { it }, + beforeValueDecoding: CodeGenerator.() -> Unit = {}, ) { when (val fieldType = field.type) { is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") { + beforeValueDecoding() generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor) } is FieldType.List -> if (field.dec.isPacked) { whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { + beforeValueDecoding() generateDecodeFieldValue(fieldType, lvalue, isPacked = true, wrapperCtor = wrapperCtor) } } else { whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${fieldType.value.wireType.name}") { + beforeValueDecoding() generateDecodeFieldValue(fieldType, lvalue, isPacked = false, wrapperCtor = wrapperCtor) } } is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") { + beforeValueDecoding() generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor) } is FieldType.OneOf -> { fieldType.dec.variants.forEach { variant -> val variantName = "${fieldType.dec.name.safeFullName()}.${variant.name}" - readMatchCase( - field = variant, - lvalue = lvalue, - wrapperCtor = { "$variantName($it)" } - ) + if (variant.type is FieldType.Message) { + // in case of a message, we must construct an empty message before reading the message + readMatchCase( + field = variant, + lvalue = "field.value", + beforeValueDecoding = { + beforeValueDecoding() + scope("val field = ($lvalue as? $variantName) ?: $variantName(${variant.type.internalCtor()}).also") { + // write the constructed oneof variant to the field + code("$lvalue = it") + } + }) + } else { + readMatchCase( + field = variant, + lvalue = lvalue, + wrapperCtor = { "$variantName($it)" }, + beforeValueDecoding = beforeValueDecoding + ) + } } } is FieldType.Message -> { whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { - // check if the the current sub message object - ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = { - code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()") - }) + if (field.presenceIdx != null) { + // check if the current sub message object was already set, if not, set a new one + // to set the field's presence tracker to true + ifBranch(condition = "!msg.presenceMask[${field.presenceIdx}]", ifBlock = { + code("$lvalue = ${fieldType.dec.value.internalClassFullName()}()") + }) + } + beforeValueDecoding() generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor) } } @@ -819,6 +843,9 @@ class ModelToKotlinCommonGenerator( } } + private fun FieldType.Message.internalCtor() = + dec.value.internalClassFullName() + "()" + private fun MessageDeclaration.internalClassFullName(): String { return name.safeFullName(MSG_INTERNAL_SUFFIX) } diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index cac6e4da2..0ad688e3e 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -75,10 +75,12 @@ data class FieldDeclaration( ) { val packedFixedSize = type.wireType == WireType.FIXED64 || type.wireType == WireType.FIXED32 + val isPartOfOneof: Boolean = dec.realContainingOneof != null + // aligns with edition settings and backward compatibility with proto2 and proto3 val nullable: Boolean = (dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() && !dec.isRepeated // repeated fields cannot be nullable (just empty) - && dec.realContainingOneof == null // upper conditions would match oneof inner fields + && !isPartOfOneof // upper conditions would match oneof inner fields && type !is FieldType.Message // messages must not be null (to conform protobuf standards) ) || type is FieldType.OneOf // all OneOf fields are nullable From 1454a8f5515f8ce083211e44b63f8d81fd98b439 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 15:46:47 +0200 Subject: [PATCH 08/11] grpc-pb: Add test for enum in oneOf Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 44 ++++++++++++------- .../src/commonTest/proto/oneof.proto | 2 + .../protobuf/ModelToKotlinCommonGenerator.kt | 2 - 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 6e233ec8f..50cf436f8 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -141,26 +141,40 @@ class ProtosTest { @Test fun testOneOf() { - val msg1 = OneOfMsg { - field = OneOfMsg.Field.Sint(23) + run { + val msg = OneOfMsg { + field = OneOfMsg.Field.Sint(23) + } + val decoded = encodeDecode(msg, OneOfMsgInternal.CODEC) + assertEquals(OneOfMsg.Field.Sint(23), decoded.field) + } + + run { + val msg = OneOfMsg { + field = OneOfMsg.Field.Fixed(21u) + } + val decoded = encodeDecode(msg, OneOfMsgInternal.CODEC) + assertEquals(OneOfMsg.Field.Fixed(21u), decoded.field) } - val decoded1 = encodeDecode(msg1, OneOfMsgInternal.CODEC) - assertEquals(OneOfMsg.Field.Sint(23), decoded1.field) - val msg2 = OneOfMsg { - field = OneOfMsg.Field.Fixed(21u) + run { + val msg = OneOfMsg { + field = OneOfMsg.Field.Other(Other { arg2 = "test" }) + } + val decoded = encodeDecode(msg, OneOfMsgInternal.CODEC) + assertIs(decoded.field) + assertNull((decoded.field as OneOfMsg.Field.Other).value.arg1) + assertEquals("test", (decoded.field as OneOfMsg.Field.Other).value.arg2) + assertNull((decoded.field as OneOfMsg.Field.Other).value.arg3) } - val decoded2 = encodeDecode(msg2, OneOfMsgInternal.CODEC) - assertEquals(OneOfMsg.Field.Fixed(21u), decoded2.field) - val msg3 = OneOfMsg { - field = OneOfMsg.Field.Other(Other { arg2 = "test" }) + run { + val msg = OneOfMsg { + field = OneOfMsg.Field.Enum(MyEnum.ONE_SECOND) + } + val decoded = encodeDecode(msg, OneOfMsgInternal.CODEC) + assertEquals(MyEnum.ONE, (decoded.field as OneOfMsg.Field.Enum).value) } - val decoded3 = encodeDecode(msg3, OneOfMsgInternal.CODEC) - assertIs(decoded3.field) - assertNull((decoded3.field as OneOfMsg.Field.Other).value.arg1) - assertEquals("test", (decoded3.field as OneOfMsg.Field.Other).value.arg2) - assertNull((decoded3.field as OneOfMsg.Field.Other).value.arg3) } @Test diff --git a/grpc/grpc-core/src/commonTest/proto/oneof.proto b/grpc/grpc-core/src/commonTest/proto/oneof.proto index cc4c3d217..09c7e5abf 100644 --- a/grpc/grpc-core/src/commonTest/proto/oneof.proto +++ b/grpc/grpc-core/src/commonTest/proto/oneof.proto @@ -1,9 +1,11 @@ import "sub_message.proto"; +import "enum.proto"; message OneOfMsg { oneof field { int32 sint = 2; fixed64 fixed = 3; test.submsg.Other other = 4; + kotlinx.rpc.grpc.test.MyEnum enum = 5; } } \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 23ef8933f..8e732d448 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -219,8 +219,6 @@ class ModelToKotlinCommonGenerator( private fun CodeGenerator.generateCodecObject(declaration: MessageDeclaration) { val msgFqName = declaration.name.safeFullName() - val downCastErrorStr = - "\${value::class.simpleName} implements ${msgFqName}, which is prohibited." val sourceFqName = "kotlinx.io.Source" val bufferFqName = "kotlinx.io.Buffer" scope("object CODEC : kotlinx.rpc.grpc.codec.MessageCodec<$msgFqName>") { From d5b466f6ef33f450eff3dce25d1bb94ca4dfd413 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 16:18:26 +0200 Subject: [PATCH 09/11] grpc-pb: Add oneOf message merge test Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 26 +++++++++++++++++++ .../protobuf/ModelToKotlinCommonGenerator.kt | 4 +-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 50cf436f8..3489cbe5d 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -8,6 +8,8 @@ import OneOfMsg import OneOfMsgInternal import Outer import OuterInternal +import asInternal +import encodeWith import invoke import kotlinx.io.Buffer import kotlinx.rpc.grpc.codec.MessageCodec @@ -177,6 +179,30 @@ class ProtosTest { } } + @Test + fun testOneOfMsgMerging() { + val part1 = OneOfMsg { + field = OneOfMsg.Field.Other(Other { arg2 = "arg2" }) + } + val part2 = OneOfMsg { + field = OneOfMsg.Field.Other(Other { arg1 = "arg1" }) + } + + val buffer = Buffer() + val encoder = WireEncoder(buffer) + part1.asInternal().encodeWith(encoder) + part2.asInternal().encodeWith(encoder) + encoder.flush() + + + val decoded = OneOfMsgInternal.CODEC.decode(buffer) + assertIs(decoded.field) + val decodedOther = (decoded.field as OneOfMsg.Field.Other).value + assertEquals("arg2", decodedOther.arg2) + assertEquals("arg1", decodedOther.arg1) + assertEquals(null, decodedOther.arg3) + } + @Test fun testOneOfLastWins() { // write two values on the oneOf field. diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 8e732d448..e6e1f68d5 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -265,7 +265,7 @@ class ModelToKotlinCommonGenerator( code("val tag = decoder.readTag() ?: break // EOF, we read the whole message") whenBlock { declaration.fields().forEach { (_, field) -> readMatchCase(field) } - whenCase("else") { code("TODO(\"Handle unknown fields\")") } + whenCase("else") { code("TODO(\"Handle unknown fields: \$tag\")") } } } ifBranch( @@ -394,7 +394,7 @@ class ModelToKotlinCommonGenerator( is FieldType.Message -> { val internalClassName = fieldType.dec.value.internalClassFullName() - code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)") + code("decoder.readMessage($lvalue.asInternal(), $internalClassName.CODEC::decodeWith)") } is FieldType.Map -> TODO() From 39403d007286ee2a128285fe49fceccf8d260ef9 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 17:45:49 +0200 Subject: [PATCH 10/11] grpc-pb: Check required fields ins submessages Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/grpc/pb/InternalMessage.kt | 9 +++--- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 30 ++++++++++++++++--- .../src/commonTest/proto/oneof.proto | 8 +++++ .../src/commonTest/proto/repeated.proto | 6 ++++ .../protobuf/ModelToKotlinCommonGenerator.kt | 29 ++++++++++++++++++ 5 files changed, 74 insertions(+), 8 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt index 32383c2d2..d877144e2 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/InternalMessage.kt @@ -17,29 +17,30 @@ public abstract class InternalMessage(fieldsWithPresence: Int) { public abstract val _size: Int } +@InternalRpcApi public class MsgFieldDelegate( private val presenceIdx: Int? = null, private val defaultProvider: (() -> T)? = null ) : ReadWriteProperty { private var valueSet = false - private var _value: T? = null + private var value: T? = null override operator fun getValue(thisRef: InternalMessage, property: KProperty<*>): T { if (!valueSet) { if (defaultProvider != null) { - _value = defaultProvider.invoke() + value = defaultProvider.invoke() valueSet = true } else { error("Property ${property.name} not initialized") } } - return _value as T + return value as T } override operator fun setValue(thisRef: InternalMessage, property: KProperty<*>, value: T) { presenceIdx?.let { thisRef.presenceMask[it] = true } - _value = value + this@MsgFieldDelegate.value = value valueSet = true } } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 3489cbe5d..7c73a2731 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -6,6 +6,7 @@ package kotlinx.rpc.grpc.pb import OneOfMsg import OneOfMsgInternal +import OneOfWithRequired import Outer import OuterInternal import asInternal @@ -87,10 +88,20 @@ class ProtosTest { } @Test - fun testPresenceCheckProto() { + fun testRepeatedWithRequiredSubField() { + assertFailsWith { + RepeatedWithRequired { + // we construct the message using the internal class, + // so it is not invoking the checkRequired method on construction + msgList = listOf(PresenceCheck { RequiredPresence = 2 }, PresenceCheckInternal()) + } + } + } + @Test + fun testPresenceCheckProto() { // Check a missing required field in a user-constructed message - assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { + assertFailsWith { PresenceCheck {} } @@ -100,7 +111,7 @@ class ProtosTest { encoder.writeFloat(2, 1f) encoder.flush() - assertFailsWith("PresenceCheck is missing required field: RequiredPresence") { + assertFailsWith { PresenceCheckInternal.CODEC.decode(buffer) } } @@ -217,6 +228,17 @@ class ProtosTest { assertEquals(OneOfMsg.Field.Fixed(123u), decoded.field) } + @Test + fun testOneOfRequiredSubField() { + assertFailsWith { + OneOfWithRequired { + // we construct the message using the internal class, + // so it is not invoking the checkRequired method on construction + field = OneOfWithRequired.Field.Msg(PresenceCheckInternal()) + } + } + } + @Test fun testOneOfNull() { // write two values on the oneOf field. @@ -239,7 +261,7 @@ class ProtosTest { @Test fun testRecursiveReqNotSet() { - assertFailsWith("RecursiveReq is missing required field: rec") { + assertFailsWith { val msg = RecursiveReq { rec = RecursiveReq { rec = RecursiveReq { diff --git a/grpc/grpc-core/src/commonTest/proto/oneof.proto b/grpc/grpc-core/src/commonTest/proto/oneof.proto index 09c7e5abf..b55ddae16 100644 --- a/grpc/grpc-core/src/commonTest/proto/oneof.proto +++ b/grpc/grpc-core/src/commonTest/proto/oneof.proto @@ -1,5 +1,6 @@ import "sub_message.proto"; import "enum.proto"; +import "presence_check.proto"; message OneOfMsg { oneof field { @@ -8,4 +9,11 @@ message OneOfMsg { test.submsg.Other other = 4; kotlinx.rpc.grpc.test.MyEnum enum = 5; } +} + +message OneOfWithRequired { + oneof field { + int32 sint = 1; + kotlinx.rpc.grpc.test.common.PresenceCheck msg = 2; + } } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/proto/repeated.proto b/grpc/grpc-core/src/commonTest/proto/repeated.proto index fad5caf75..7451b33ac 100644 --- a/grpc/grpc-core/src/commonTest/proto/repeated.proto +++ b/grpc/grpc-core/src/commonTest/proto/repeated.proto @@ -1,5 +1,7 @@ syntax = "proto3"; +import "presence_check.proto"; + package kotlinx.rpc.grpc.test.common; message Repeated { @@ -13,4 +15,8 @@ message Repeated { message Other { int32 a = 1; } +} + +message RepeatedWithRequired { + repeated PresenceCheck msgList = 1; } \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index e6e1f68d5..3c731cdca 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -530,6 +530,35 @@ class ModelToKotlinCommonGenerator( code("${field.name}.asInternal().checkRequiredFields()") }) } + + // check submessages in oneofs + declaration.fields().filter { it.second.type is FieldType.OneOf }.forEach { (_, field) -> + val oneOfType = field.type as FieldType.OneOf + val messageVariants = oneOfType.dec.variants.filter { it.type is FieldType.Message } + if (messageVariants.isEmpty()) return@forEach + + scope("${field.name}?.also") { + whenBlock { + messageVariants.forEach { variant -> + val variantClassName = "${field.type.dec.name.safeFullName()}.${variant.name}" + whenCase("it is $variantClassName") { + code("it.value.asInternal().checkRequiredFields()") + } + } + } + } + } + + // check submessages in lists + declaration.fields().filter { it.second.type is FieldType.List }.forEach { (_, field) -> + val listType = field.type as FieldType.List + if (listType.value !is FieldType.Message) return@forEach + + scope("${field.name}.forEach") { + code("it.asInternal().checkRequiredFields()") + } + } + } private fun CodeGenerator.generateInternalComputeSize(declaration: MessageDeclaration) { From e2050736e71c1c40884afa02eafc8f1bdbac2912 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Wed, 6 Aug 2025 18:12:12 +0200 Subject: [PATCH 11/11] grpc-pb: Address PR comments Signed-off-by: Johannes Zottele --- grpc/grpc-core/src/commonTest/proto/nested.proto | 7 ------- .../kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/grpc/grpc-core/src/commonTest/proto/nested.proto b/grpc/grpc-core/src/commonTest/proto/nested.proto index 18e0bf6cd..d0d715ad4 100644 --- a/grpc/grpc-core/src/commonTest/proto/nested.proto +++ b/grpc/grpc-core/src/commonTest/proto/nested.proto @@ -1,10 +1,3 @@ -// Protocol Buffers - Google's data interchange format -// Copyright 2023 Google LLC. All rights reserved. -// -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file or at -// https://developers.google.com/open-source/licenses/bsd - syntax = "proto2"; package test.nested; diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 3c731cdca..fd37e8738 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -315,7 +315,7 @@ class ModelToKotlinCommonGenerator( lvalue = "field.value", beforeValueDecoding = { beforeValueDecoding() - scope("val field = ($lvalue as? $variantName) ?: $variantName(${variant.type.internalCtor()}).also") { + scope("val field = ($lvalue as? $variantName) ?: $variantName(${variant.type.internalConstructor()}).also") { // write the constructed oneof variant to the field code("$lvalue = it") } @@ -870,7 +870,7 @@ class ModelToKotlinCommonGenerator( } } - private fun FieldType.Message.internalCtor() = + private fun FieldType.Message.internalConstructor() = dec.value.internalClassFullName() + "()" private fun MessageDeclaration.internalClassFullName(): String {