Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions protobuf/protobuf-api/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ kotlin {

commonTest {
dependencies {
implementation(projects.protobuf.protobufWkt)
implementation(libs.kotlin.test)
implementation(libs.coroutines.test)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -446,4 +446,14 @@ class ProtosTest {
(decoded.oneOfWithGroup as WithGroups.OneOfWithGroup.Testgroup).value.value
)
}

@Test
fun testGroupAsStandalone() {
val standaloneGroup = WithGroups.FirstGroup {
value = 42u
}

val decoded = encodeDecode(standaloneGroup, marshallerOf<WithGroups.FirstGroup>())
assertEquals(standaloneGroup.value, decoded.value)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,9 @@ class ModelToProtobufKotlinCommonGenerator(
private fun CodeGenerator.generatePublicMessage(declaration: MessageDeclaration) {
if (!declaration.isUserFacing) return

val annotations = mutableListOf<ScopedFormattedString>()
if (!declaration.isGroup) {
annotations.add(
FqName.Annotations.GeneratedProtoMessage.scopedAnnotation()
)
}
val annotations = listOf(
FqName.Annotations.GeneratedProtoMessage.scopedAnnotation()
)

generatedMetadata.protoNamesList.add(declaration.name)

Expand Down Expand Up @@ -773,8 +770,6 @@ class ModelToProtobufKotlinCommonGenerator(

private fun CodeGenerator.generateMarshallerObject(declaration: MessageDeclaration) {
if (!declaration.isUserFacing) return
// the MARSHALLER object is not necessary for groups, as they are inlined messages
if (declaration.isGroup) return

clazz(
name = declaration.marshallerObjectName.simpleName,
Expand Down Expand Up @@ -808,8 +803,11 @@ class ModelToProtobufKotlinCommonGenerator(
scope("%T(source).use".scoped(FqName.RpcClasses.WireDecoder)) {
code("val msg = %T()".scoped(declaration.internalClassName))
scope("checkForPlatformDecodeException".scoped(), nlAfterClosed = false) {
// if the message declaration is a group, we must pass null as the
// startGroup tag to indicate that this message is decoded as standalone (like a normal message)
val groupExtraArg = if (declaration.isGroup) ", null" else ""
code(
"%T.decodeWith(msg, it, config as? %T)".scoped(
"%T.decodeWith(msg, it, config as? %T$groupExtraArg)".scoped(
declaration.internalClassName,
FqName.RpcClasses.ProtobufConfig,
)
Expand All @@ -828,7 +826,6 @@ class ModelToProtobufKotlinCommonGenerator(

private fun CodeGenerator.generateDescriptorObject(declaration: MessageDeclaration) {
if (!declaration.isUserFacing) return
if (declaration.isGroup) return

clazz(
name = "DESCRIPTOR",
Expand Down Expand Up @@ -881,8 +878,12 @@ class ModelToProtobufKotlinCommonGenerator(
.scoped(declaration.internalClassName, FqName.RpcClasses.WireDecoder, FqName.RpcClasses.ProtobufConfig)

if (declaration.isGroup) {
// if the message is a group message, the decoder accepts an optional startGroup tag, which indicates
// that the decoding of the message must end with an END_GROUP tag of the same fieldNr.
// if the startGroup tag is null, we treat it like a normal message.
// the argument is not default null, to avoid that we forget to set it when changing the generator.
args = args.merge(FqName.RpcClasses.KTag.scoped()) { start, kTag ->
"$start, startGroup: $kTag"
"$start, startGroup: $kTag?"
}
}

Expand All @@ -895,14 +896,21 @@ class ModelToProtobufKotlinCommonGenerator(
) {
whileBlock("true".scoped()) {
if (declaration.isGroup) {
code(
"val tag = decoder.readTag() ?: throw %T(\"Missing END_GROUP tag for field: \${startGroup.fieldNr}.\")"
.scoped(FqName.RpcClasses.ProtobufDecodingException)
)
scope("val tag = decoder.readTag() ?: run".scoped()) {
// if the startGroup tag is set, we decode the message as an inline group message,
// so in case of the tag being null, we know that the payload is malformed
scope("startGroup?.let".scoped()) {
code("throw %T(\"Missing END_GROUP tag for field: \${startGroup.fieldNr}.\")"
.scoped(FqName.RpcClasses.ProtobufDecodingException))
}
// if the startGroup is null, we decode the message like a normal non-group message,
// so we stop when the tag is null
code("return".scoped())
}
ifBranch(condition = "tag.wireType == %T".scoped(FqName.RpcClasses.WireType_END_GROUP), ifBlock = {
ifBranch(condition = "tag.fieldNr != startGroup.fieldNr".scoped(), ifBlock = {
ifBranch(condition = "tag.fieldNr != startGroup?.fieldNr".scoped(), ifBlock = {
code(
"throw %T(\"Wrong END_GROUP tag. Expected \${startGroup.fieldNr}, got \${tag.fieldNr}.\")"
"throw %T(\"Wrong END_GROUP tag. Expected \${startGroup?.fieldNr}, got \${tag.fieldNr}.\")"
.scoped(FqName.RpcClasses.ProtobufDecodingException)
)
})
Expand Down
Loading