diff --git a/protobuf/protobuf-api/build.gradle.kts b/protobuf/protobuf-api/build.gradle.kts index ac453aee0..8d5221f1b 100644 --- a/protobuf/protobuf-api/build.gradle.kts +++ b/protobuf/protobuf-api/build.gradle.kts @@ -32,6 +32,7 @@ kotlin { commonTest { dependencies { + implementation(projects.protobuf.protobufWkt) implementation(libs.kotlin.test) implementation(libs.coroutines.test) } diff --git a/protobuf/protobuf-api/src/commonTest/kotlin/kotlinx/rpc/protobuf/test/ProtosTest.kt b/protobuf/protobuf-api/src/commonTest/kotlin/kotlinx/rpc/protobuf/test/ProtosTest.kt index bddf90ff9..fff541638 100644 --- a/protobuf/protobuf-api/src/commonTest/kotlin/kotlinx/rpc/protobuf/test/ProtosTest.kt +++ b/protobuf/protobuf-api/src/commonTest/kotlin/kotlinx/rpc/protobuf/test/ProtosTest.kt @@ -446,4 +446,14 @@ class ProtosTest { (decoded.oneOfWithGroup as WithGroups.OneOfWithGroup.Testgroup).value.value ) } + + @Test + fun testGroupAsStandalone() { + val standaloneGroup = WithGroups.FirstGroup { + value = 42u + } + + val decoded = encodeDecode(standaloneGroup, marshallerOf()) + assertEquals(standaloneGroup.value, decoded.value) + } } diff --git a/protoc-gen/protobuf/src/main/kotlin/kotlinx/rpc/protoc/gen/ModelToProtobufKotlinCommonGenerator.kt b/protoc-gen/protobuf/src/main/kotlin/kotlinx/rpc/protoc/gen/ModelToProtobufKotlinCommonGenerator.kt index 9fa64f94e..39e378d1a 100644 --- a/protoc-gen/protobuf/src/main/kotlin/kotlinx/rpc/protoc/gen/ModelToProtobufKotlinCommonGenerator.kt +++ b/protoc-gen/protobuf/src/main/kotlin/kotlinx/rpc/protoc/gen/ModelToProtobufKotlinCommonGenerator.kt @@ -101,12 +101,9 @@ class ModelToProtobufKotlinCommonGenerator( private fun CodeGenerator.generatePublicMessage(declaration: MessageDeclaration) { if (!declaration.isUserFacing) return - val annotations = mutableListOf() - if (!declaration.isGroup) { - annotations.add( - FqName.Annotations.GeneratedProtoMessage.scopedAnnotation() - ) - } + val annotations = listOf( + FqName.Annotations.GeneratedProtoMessage.scopedAnnotation() + ) generatedMetadata.protoNamesList.add(declaration.name) @@ -773,8 +770,6 @@ class ModelToProtobufKotlinCommonGenerator( private fun CodeGenerator.generateMarshallerObject(declaration: MessageDeclaration) { if (!declaration.isUserFacing) return - // the MARSHALLER object is not necessary for groups, as they are inlined messages - if (declaration.isGroup) return clazz( name = declaration.marshallerObjectName.simpleName, @@ -808,8 +803,11 @@ class ModelToProtobufKotlinCommonGenerator( scope("%T(source).use".scoped(FqName.RpcClasses.WireDecoder)) { code("val msg = %T()".scoped(declaration.internalClassName)) scope("checkForPlatformDecodeException".scoped(), nlAfterClosed = false) { + // if the message declaration is a group, we must pass null as the + // startGroup tag to indicate that this message is decoded as standalone (like a normal message) + val groupExtraArg = if (declaration.isGroup) ", null" else "" code( - "%T.decodeWith(msg, it, config as? %T)".scoped( + "%T.decodeWith(msg, it, config as? %T$groupExtraArg)".scoped( declaration.internalClassName, FqName.RpcClasses.ProtobufConfig, ) @@ -828,7 +826,6 @@ class ModelToProtobufKotlinCommonGenerator( private fun CodeGenerator.generateDescriptorObject(declaration: MessageDeclaration) { if (!declaration.isUserFacing) return - if (declaration.isGroup) return clazz( name = "DESCRIPTOR", @@ -881,8 +878,12 @@ class ModelToProtobufKotlinCommonGenerator( .scoped(declaration.internalClassName, FqName.RpcClasses.WireDecoder, FqName.RpcClasses.ProtobufConfig) if (declaration.isGroup) { + // if the message is a group message, the decoder accepts an optional startGroup tag, which indicates + // that the decoding of the message must end with an END_GROUP tag of the same fieldNr. + // if the startGroup tag is null, we treat it like a normal message. + // the argument is not default null, to avoid that we forget to set it when changing the generator. args = args.merge(FqName.RpcClasses.KTag.scoped()) { start, kTag -> - "$start, startGroup: $kTag" + "$start, startGroup: $kTag?" } } @@ -895,14 +896,21 @@ class ModelToProtobufKotlinCommonGenerator( ) { whileBlock("true".scoped()) { if (declaration.isGroup) { - code( - "val tag = decoder.readTag() ?: throw %T(\"Missing END_GROUP tag for field: \${startGroup.fieldNr}.\")" - .scoped(FqName.RpcClasses.ProtobufDecodingException) - ) + scope("val tag = decoder.readTag() ?: run".scoped()) { + // if the startGroup tag is set, we decode the message as an inline group message, + // so in case of the tag being null, we know that the payload is malformed + scope("startGroup?.let".scoped()) { + code("throw %T(\"Missing END_GROUP tag for field: \${startGroup.fieldNr}.\")" + .scoped(FqName.RpcClasses.ProtobufDecodingException)) + } + // if the startGroup is null, we decode the message like a normal non-group message, + // so we stop when the tag is null + code("return".scoped()) + } ifBranch(condition = "tag.wireType == %T".scoped(FqName.RpcClasses.WireType_END_GROUP), ifBlock = { - ifBranch(condition = "tag.fieldNr != startGroup.fieldNr".scoped(), ifBlock = { + ifBranch(condition = "tag.fieldNr != startGroup?.fieldNr".scoped(), ifBlock = { code( - "throw %T(\"Wrong END_GROUP tag. Expected \${startGroup.fieldNr}, got \${tag.fieldNr}.\")" + "throw %T(\"Wrong END_GROUP tag. Expected \${startGroup?.fieldNr}, got \${tag.fieldNr}.\")" .scoped(FqName.RpcClasses.ProtobufDecodingException) ) })