diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 7c73a2731..001e0a2a8 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -14,11 +14,8 @@ import encodeWith import invoke import kotlinx.io.Buffer import kotlinx.rpc.grpc.codec.MessageCodec -import kotlinx.rpc.grpc.test.MyEnum -import kotlinx.rpc.grpc.test.UsingEnum -import kotlinx.rpc.grpc.test.UsingEnumInternal +import kotlinx.rpc.grpc.test.* import kotlinx.rpc.grpc.test.common.* -import kotlinx.rpc.grpc.test.invoke import test.nested.* import test.recursive.Recursive import test.recursive.RecursiveInternal @@ -31,7 +28,7 @@ class ProtosTest { private fun encodeDecode( msg: M, - codec: MessageCodec + codec: MessageCodec, ): M { val source = codec.encode(msg) return codec.decode(source) @@ -357,4 +354,38 @@ class ProtosTest { assertEquals("fourth", decoded.other.arg3) } + + @Test + fun testMap() { + val msg = TestMap { + primitives = mapOf("one" to 1, "two" to 2, "three" to 3) + messages = mapOf( + 1 to PresenceCheck { RequiredPresence = 1 }, + 2 to PresenceCheck { RequiredPresence = 2; OptionalPresence = 3F } + ) + } + + val decoded = encodeDecode(msg, TestMapInternal.CODEC) + assertEquals(msg.primitives, decoded.primitives) + assertEquals(msg.messages.size, decoded.messages.size) + for ((key, value) in msg.messages) { + assertEquals(value.RequiredPresence, decoded.messages[key]!!.RequiredPresence) + assertEquals(value.OptionalPresence, decoded.messages[key]!!.OptionalPresence) + } + } + + @Test + fun testMapRequiredSubField() { + // we use the internal constructor to avoid a "missing required field" error during object construction + val missingRequiredMessage = PresenceCheckInternal() + + assertFailsWith { + val msg = TestMap { + messages = mapOf( + 2 to missingRequiredMessage + ) + } + } + } + } diff --git a/grpc/grpc-core/src/commonTest/proto/exclude/test_map.proto b/grpc/grpc-core/src/commonTest/proto/test_map.proto similarity index 51% rename from grpc/grpc-core/src/commonTest/proto/exclude/test_map.proto rename to grpc/grpc-core/src/commonTest/proto/test_map.proto index cbb3f8e61..816c292e8 100644 --- a/grpc/grpc-core/src/commonTest/proto/exclude/test_map.proto +++ b/grpc/grpc-core/src/commonTest/proto/test_map.proto @@ -2,9 +2,9 @@ syntax = "proto3"; package kotlinx.rpc.grpc.test; -import "reference_package.proto"; +import "presence_check.proto"; message TestMap { map primitives = 1; - map references = 2; -} + map messages = 2; +} \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt index 0dd8db665..cd8ddfbb5 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt @@ -72,10 +72,11 @@ open class CodeGenerator( suffix: String = "", nlAfterClosed: Boolean = true, openingBracket: Boolean = true, + paramDecl: String = "", block: (CodeGenerator.() -> Unit)? = null, ) { addLine(prefix) - scopeWithSuffix(suffix, openingBracket, nlAfterClosed, block) + scopeWithSuffix(suffix, openingBracket, nlAfterClosed, paramDecl, block) } internal fun ifBranch( @@ -122,6 +123,7 @@ open class CodeGenerator( suffix: String = "", openingBracket: Boolean = true, nlAfterClosed: Boolean = true, + paramDeclaration: String = "", block: (CodeGenerator.() -> Unit)? = null, ) { if (block == null) { @@ -139,7 +141,7 @@ open class CodeGenerator( } if (openingBracket) { - append(" {") + append(" { $paramDeclaration") } newLine() append(nested.build().trimEnd()) diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index fd37e8738..598701b7e 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -132,6 +132,8 @@ class ModelToKotlinCommonGenerator( @Suppress("detekt.CyclomaticComplexMethod") private fun CodeGenerator.generatePublicMessage(declaration: MessageDeclaration) { + if (!declaration.isUserFacing) return + clazz( name = declaration.name.simpleName, declarationType = DeclarationType.Interface, @@ -162,20 +164,32 @@ class ModelToKotlinCommonGenerator( @Suppress("detekt.CyclomaticComplexMethod") private fun CodeGenerator.generateInternalMessage(declaration: MessageDeclaration) { val internalClassName = declaration.internalClassName() + + val annotations = buildList { + add("@$INTERNAL_RPC_API_ANNO") + if (declaration.isUserFacing) { + add("@$WITH_CODEC_ANNO($internalClassName.CODEC::class)") + } + } + val superTypes = buildList { + if (declaration.isUserFacing) { + add(declaration.name.safeFullName()) + } + add("$PB_PKG.InternalMessage(fieldsWithPresence = ${declaration.presenceMaskSize})") + } + clazz( name = internalClassName, - annotations = listOf("@$INTERNAL_RPC_API_ANNO", "@$WITH_CODEC_ANNO($internalClassName.CODEC::class)"), + annotations = annotations, declarationType = DeclarationType.Class, - superTypes = listOf( - declaration.name.safeFullName(), - "$PB_PKG.InternalMessage(fieldsWithPresence = ${declaration.presenceMaskSize})" - ), + superTypes = superTypes, ) { generatePresenceIndicesObject(declaration) code("override val _size: Int by lazy { computeSize() }") + val override = if (declaration.isUserFacing) "override " else "" declaration.fields().forEach { (fieldDeclaration, field) -> val value = when { field.nullable -> { @@ -191,7 +205,7 @@ class ModelToKotlinCommonGenerator( } } - code("override var $fieldDeclaration $value") + code("$override var $fieldDeclaration $value") newLine() } @@ -200,6 +214,9 @@ class ModelToKotlinCommonGenerator( } generateCodecObject(declaration) + + // required for decodeWith extension + code("companion object") } } @@ -218,6 +235,8 @@ class ModelToKotlinCommonGenerator( } private fun CodeGenerator.generateCodecObject(declaration: MessageDeclaration) { + if (!declaration.isUserFacing) return + val msgFqName = declaration.name.safeFullName() val sourceFqName = "kotlinx.io.Source" val bufferFqName = "kotlinx.io.Buffer" @@ -233,7 +252,7 @@ class ModelToKotlinCommonGenerator( function("decode", modifiers = "override", args = "stream: $sourceFqName", returnType = msgFqName) { scope("$PB_PKG.WireDecoder(stream as $bufferFqName).use") { code("val msg = ${declaration.internalClassFullName()}()") - code("${declaration.internalClassFullName()}.CODEC.decodeWith(msg, it)") + code("${declaration.internalClassFullName()}.decodeWith(msg, it)") code("msg.checkRequiredFields()") code("return msg") } @@ -241,17 +260,21 @@ class ModelToKotlinCommonGenerator( } } - private fun CodeGenerator.generateMessageConstructor(declaration: MessageDeclaration) = function( - name = "invoke", - modifiers = "operator", - args = "body: ${declaration.internalClassFullName()}.() -> Unit", - contextReceiver = "${declaration.name.safeFullName()}.Companion", - returnType = declaration.name.safeFullName(), - ) { - code("val msg = ${declaration.internalClassFullName()}().apply(body)") - // check if the user set all required fields - code("msg.checkRequiredFields()") - code("return msg") + private fun CodeGenerator.generateMessageConstructor(declaration: MessageDeclaration) { + if (!declaration.isUserFacing) return + + function( + name = "invoke", + modifiers = "operator", + args = "body: ${declaration.internalClassFullName()}.() -> Unit", + contextReceiver = "${declaration.name.safeFullName()}.Companion", + returnType = declaration.name.safeFullName(), + ) { + code("val msg = ${declaration.internalClassFullName()}().apply(body)") + // check if the user set all required fields + code("msg.checkRequiredFields()") + code("return msg") + } } private fun CodeGenerator.generateMessageDecoder(declaration: MessageDeclaration) = function( @@ -259,7 +282,7 @@ class ModelToKotlinCommonGenerator( modifiers = "internal", args = "msg: ${declaration.internalClassFullName()}, decoder: $PB_PKG.WireDecoder", annotations = listOf("@$INTERNAL_RPC_API_ANNO"), - contextReceiver = "${declaration.internalClassFullName()}.CODEC" + contextReceiver = "${declaration.internalClassFullName()}.Companion" ) { whileBlock("!decoder.hadError()") { code("val tag = decoder.readTag() ?: break // EOF, we read the whole message") @@ -273,7 +296,7 @@ class ModelToKotlinCommonGenerator( ifBlock = { code("error(\"Error during decoding of ${declaration.name.simpleName}\")") } ) - // TODO: Make a lists immutable + // TODO: Make lists and maps immutable (KRPC-190) } private fun CodeGenerator.readMatchCase( @@ -345,7 +368,10 @@ class ModelToKotlinCommonGenerator( } } - is FieldType.Map -> TODO() + is FieldType.Map -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.LENGTH_DELIMITED") { + beforeValueDecoding() + generateDecodeFieldValue(fieldType, lvalue, wrapperCtor = wrapperCtor) + } } } @@ -353,7 +379,7 @@ class ModelToKotlinCommonGenerator( fieldType: FieldType, lvalue: String, isPacked: Boolean = false, - wrapperCtor: (String) -> String = { it } + wrapperCtor: (String) -> String = { it }, ) { when (fieldType) { is FieldType.IntegralType -> { @@ -372,7 +398,7 @@ class ModelToKotlinCommonGenerator( else -> generateDecodeFieldValue(fieldType.value, "val elem", wrapperCtor = wrapperCtor) } - code("($lvalue as ArrayList).add(elem)") + code("($lvalue as MutableList).add(elem)") } is FieldType.Enum -> { @@ -394,10 +420,21 @@ class ModelToKotlinCommonGenerator( is FieldType.Message -> { val internalClassName = fieldType.dec.value.internalClassFullName() - code("decoder.readMessage($lvalue.asInternal(), $internalClassName.CODEC::decodeWith)") + code("decoder.readMessage($lvalue.asInternal(), $internalClassName::decodeWith)") } - is FieldType.Map -> TODO() + is FieldType.Map -> { + val entryClassName = fieldType.entry.dec.internalClassFullName() + scope("with($entryClassName())") { + generateDecodeFieldValue( + fieldType = FieldType.Message(lazy { fieldType.entry.dec }), + lvalue = "this", + isPacked = false, + wrapperCtor = wrapperCtor + ) + code("($lvalue as MutableMap)[key] = value") + } + } } } @@ -417,40 +454,56 @@ class ModelToKotlinCommonGenerator( val fieldName = field.name if (field.nullable) { scope("$fieldName?.also") { - writeFieldValue(field, "it") + generateEncodeFieldValue(field, "it") } - } else if (!field.dec.hasPresence()) { - ifBranch(condition = field.notDefaultCheck(), ifBlock = { - writeFieldValue(field, field.name) + } else if (field.dec.hasPresence()) { + ifBranch(condition = "presenceMask[${field.presenceIdx}]", ifBlock = { + generateEncodeFieldValue(field, field.name) }) } else { - ifBranch(condition = "presenceMask[${field.presenceIdx}]", ifBlock = { - writeFieldValue(field, field.name) + ifBranch(condition = field.notDefaultCheck(), ifBlock = { + generateEncodeFieldValue(field, field.name) }) } } } - private fun CodeGenerator.writeFieldValue(field: FieldDeclaration, valueVar: String) { - var encFunc = field.type.decodeEncodeFuncName() - val number = field.number - when (val fieldType = field.type) { + private fun CodeGenerator.generateEncodeFieldValue( + field: FieldDeclaration, + valueVar: String, + ) { + generateEncodeFieldValue( + valueVar, field.type, number = field.number, + isPacked = field.dec.isPacked, + packedWithFixedSize = field.packedFixedSize + ) + } + + private fun CodeGenerator.generateEncodeFieldValue( + valueVar: String, + type: FieldType, + number: Int, + isPacked: Boolean, + packedWithFixedSize: Boolean, + ) { + var encFunc = type.decodeEncodeFuncName() + when (val fieldType = type) { is FieldType.IntegralType -> code("encoder.write${encFunc!!}(fieldNr = $number, value = $valueVar)") is FieldType.List -> { encFunc = fieldType.value.decodeEncodeFuncName() when { - field.dec.isPacked && field.packedFixedSize -> + isPacked && packedWithFixedSize -> code("encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar)") - field.dec.isPacked && !field.packedFixedSize -> + isPacked && !packedWithFixedSize -> code( "encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar, fieldSize = ${ - field.type.valueSizeCall(valueVar, number, true) + type.valueSizeCall(valueVar, number, true) })" ) fieldType.value is FieldType.Message -> scope("$valueVar.forEach") { - code("encoder.writeMessage(fieldNr = ${field.number}, value = it.asInternal()) { encodeWith(it) }") + code("encoder.writeMessage(fieldNr = ${number}, value = it.asInternal()) { encodeWith(it) }") } else -> { @@ -467,14 +520,25 @@ class ModelToKotlinCommonGenerator( is FieldType.OneOf -> whenBlock("val value = $valueVar") { fieldType.dec.variants.forEach { variant -> whenCase("is ${fieldType.dec.name.safeFullName()}.${variant.name}") { - writeFieldValue(variant, "value.value") + generateEncodeFieldValue(variant, "value.value") } } } - is FieldType.Map -> TODO() + is FieldType.Map -> { + scope("$valueVar.forEach", paramDecl = "kEntry ->") { + generateMapConstruction(fieldType, "kEntry.key", "kEntry.value") + scope(".also", paramDecl = "entry ->") { + generateEncodeFieldValue( + valueVar = "entry", + type = FieldType.Message(lazy { fieldType.entry.dec }), + number = number, isPacked = false, packedWithFixedSize = false + ) + } + } + } - is FieldType.Message -> code("encoder.writeMessage(fieldNr = ${field.number}, value = $valueVar.asInternal()) { encodeWith(it) }") + is FieldType.Message -> code("encoder.writeMessage(fieldNr = ${number}, value = $valueVar.asInternal()) { encodeWith(it) }") } } @@ -512,10 +576,9 @@ class ModelToKotlinCommonGenerator( contextReceiver = declaration.internalClassFullName(), ) { val requiredFields = declaration.actualFields.filter { it.dec.isRequired } - val submessages = declaration.actualFields.filter { it.type is FieldType.Message } - if (submessages.isEmpty() && requiredFields.isEmpty()) { - code("// no fields to check") + if (requiredFields.isEmpty()) { + code("// no required fields to check") } requiredFields.forEach { field -> @@ -525,14 +588,14 @@ class ModelToKotlinCommonGenerator( } // check submessages - submessages.forEach { field -> + declaration.actualFields.filter { it.type is FieldType.Message }.forEach { field -> ifBranch(condition = "presenceMask[${field.presenceIdx}]", ifBlock = { code("${field.name}.asInternal().checkRequiredFields()") }) } // check submessages in oneofs - declaration.fields().filter { it.second.type is FieldType.OneOf }.forEach { (_, field) -> + declaration.actualFields.filter { it.type is FieldType.OneOf }.forEach { field -> val oneOfType = field.type as FieldType.OneOf val messageVariants = oneOfType.dec.variants.filter { it.type is FieldType.Message } if (messageVariants.isEmpty()) return@forEach @@ -550,7 +613,7 @@ class ModelToKotlinCommonGenerator( } // check submessages in lists - declaration.fields().filter { it.second.type is FieldType.List }.forEach { (_, field) -> + declaration.actualFields.filter { it.type is FieldType.List }.forEach { field -> val listType = field.type as FieldType.List if (listType.value !is FieldType.Message) return@forEach @@ -559,6 +622,17 @@ class ModelToKotlinCommonGenerator( } } + // check submessage in maps + declaration.actualFields.filter { it.type is FieldType.Map }.forEach { field -> + val mapType = field.type as FieldType.Map + // we only have to check the value, as the key cannot be a message + if (mapType.entry.value !is FieldType.Message) return@forEach + + scope("${field.name}.values.forEach") { + code("it.asInternal().checkRequiredFields()") + } + } + } private fun CodeGenerator.generateInternalComputeSize(declaration: MessageDeclaration) { @@ -590,14 +664,23 @@ class ModelToKotlinCommonGenerator( } private fun CodeGenerator.generateInternalCastExtension(declaration: MessageDeclaration) { + val internalClassName = declaration.internalClassFullName() + val ctxReceiver = if (declaration.isUserFacing) declaration.name.safeFullName() else internalClassName + + // we generate the asInternal extension even for non-user-facing message classes (map entry) + // to avoid edge-cases when generating other code that uses the asInternal() extension. function( "asInternal", modifiers = "internal", annotations = listOf("@$INTERNAL_RPC_API_ANNO"), - contextReceiver = declaration.name.safeFullName(), - returnType = declaration.internalClassFullName(), + contextReceiver = ctxReceiver, + returnType = internalClassName, ) { - code("return this as? ${declaration.internalClassFullName()} ?: error(\"Message \${this::class.simpleName} is a non-internal message type.\")") + if (ctxReceiver == internalClassName) { + code("return this") + } else { + code("return this as? $internalClassName ?: error(\"Message \${this::class.simpleName} is a non-internal message type.\")") + } } } @@ -615,9 +698,16 @@ class ModelToKotlinCommonGenerator( is FieldType.Message, FieldType.IntegralType.STRING, - FieldType.IntegralType.BYTES -> code("result += $valueSize.let { $tagSize + ${int32SizeCall("it")} + it }") + FieldType.IntegralType.BYTES, + -> code("result += $valueSize.let { $tagSize + ${int32SizeCall("it")} + it }") + + is FieldType.Map -> { + scope("result += ${field.name}.entries.sumOf", paramDecl = "kEntry ->") { + generateMapConstruction(field.type, "kEntry.key", "kEntry.value") + code("._size") + } + } - is FieldType.Map -> TODO() is FieldType.OneOf -> whenBlock("val value = $variable") { field.type.dec.variants.forEach { variant -> val variantName = "${field.type.dec.name.safeFullName()}.${variant.name}" @@ -640,7 +730,16 @@ class ModelToKotlinCommonGenerator( FieldType.IntegralType.SINT32, FieldType.IntegralType.SINT64, FieldType.IntegralType.SFIXED32, - FieldType.IntegralType.SFIXED64 -> code("result += ($tagSize + $valueSize)") + FieldType.IntegralType.SFIXED64, + -> code("result += ($tagSize + $valueSize)") + } + } + + private fun CodeGenerator.generateMapConstruction(map: FieldType.Map, keyVar: String, valueVar: String) { + val entryClass = map.entry.dec.internalClassFullName() + scope("$entryClass().apply", nlAfterClosed = false) { + code("key = $keyVar") + code("value = $valueVar") } } @@ -684,11 +783,11 @@ class ModelToKotlinCommonGenerator( } is FieldType.List -> "$name.isNotEmpty()" - is FieldType.Message -> "" + is FieldType.Message -> error("Message fields should not be checked for default values.") is FieldType.Enum -> "${fieldType.defaultValue} != $name" - else -> TODO("Field: $name, type: $fieldType") + else -> "$name.isNotEmpty()" } } @@ -744,7 +843,7 @@ class ModelToKotlinCommonGenerator( } is FieldType.Map -> { - val entry by type.entry + val entry = type.entry val fqKey = when (val key = entry.key) { is FieldType.Message -> key.dec.value.name diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt index d9f013944..6f85829b8 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt @@ -257,8 +257,9 @@ private fun Descriptors.FieldDescriptor.modelType(): FieldType { if (isMapField) { val keyType = messageType.findFieldByName("key").modelType() val valType = messageType.findFieldByName("value").modelType() - val mapEntry = FieldType.Map.Entry(keyType, valType) - return FieldType.Map(lazy { mapEntry }) + val mapEntryDec = messageType.toModel() + val mapEntry = FieldType.Map.Entry(mapEntryDec, keyType, valType) + return FieldType.Map(mapEntry) } if (isRepeated) { diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt index d1eda9206..c2de6b18d 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt @@ -20,16 +20,16 @@ sealed interface FieldType { val isPackable: Boolean get() = false data class List(val value: FieldType) : FieldType { - override val defaultValue: String = "arrayListOf()" + override val defaultValue: String = "mutableListOf()" override val wireType: WireType = value.wireType override val isPackable: Boolean = value.isPackable } - data class Map(val entry: Lazy) : FieldType { - override val defaultValue: String = "emptyMap()" + data class Map(val entry: Entry) : FieldType { + override val defaultValue: String = "mutableMapOf()" override val wireType: WireType = WireType.LENGTH_DELIMITED - data class Entry(val key: FieldType, val value: FieldType) + data class Entry(val dec: MessageDeclaration, val key: FieldType, val value: FieldType) } data class Enum(val dec: EnumDeclaration) : FieldType { diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt index 0ad688e3e..dfe05671b 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt @@ -30,7 +30,10 @@ data class MessageDeclaration( val nestedDeclarations: List, val doc: String?, val dec: Descriptors.Descriptor, -) +) { + val isMapEntry = dec.options.mapEntry + val isUserFacing = !isMapEntry +} data class EnumDeclaration( val name: FqName,