Skip to content

Commit 635832d

Browse files
committed
grpc-pb: Support OneOf
Signed-off-by: Johannes Zottele <[email protected]>
1 parent 9bdcdfe commit 635832d

File tree

6 files changed

+127
-45
lines changed

6 files changed

+127
-45
lines changed

grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
package kotlinx.rpc.grpc.pb
66

7+
import OneOfMsg
8+
import OneOfMsgInternal
9+
import invoke
710
import kotlinx.io.Buffer
811
import kotlinx.rpc.grpc.internal.MessageCodec
912
import kotlinx.rpc.grpc.test.Enum
@@ -14,6 +17,7 @@ import kotlinx.rpc.grpc.test.invoke
1417
import kotlin.test.Test
1518
import kotlin.test.assertEquals
1619
import kotlin.test.assertFailsWith
20+
import kotlin.test.assertNull
1721

1822
class ProtosTest {
1923

@@ -124,4 +128,42 @@ class ProtosTest {
124128
assertEquals(Enum.ZERO, decoded.enum)
125129
}
126130

131+
@Test
132+
fun testOneOf() {
133+
val msg1 = OneOfMsg {
134+
field = OneOfMsg.Field.Sint(23)
135+
}
136+
val decoded1 = decodeEncode(msg1, OneOfMsgInternal.CODEC)
137+
assertEquals(OneOfMsg.Field.Sint(23), decoded1.field)
138+
139+
val msg2 = OneOfMsg {
140+
field = OneOfMsg.Field.Fixed(21u)
141+
}
142+
val decoded2 = decodeEncode(msg2, OneOfMsgInternal.CODEC)
143+
assertEquals(OneOfMsg.Field.Fixed(21u), decoded2.field)
144+
}
145+
146+
@Test
147+
fun testOneOfLastWins() {
148+
// write two values on the oneOf field.
149+
// the second value must be the one stored during decoding.
150+
val buffer = Buffer()
151+
val encoder = WireEncoder(buffer)
152+
encoder.writeInt32(2, 99)
153+
encoder.writeFixed64(3, 123u)
154+
encoder.flush()
155+
156+
val decoded = OneOfMsgInternal.CODEC.decode(buffer)
157+
assertEquals(OneOfMsg.Field.Fixed(123u), decoded.field)
158+
}
159+
160+
@Test
161+
fun testOneOfNull() {
162+
// write two values on the oneOf field.
163+
// the second value must be the one stored during decoding.
164+
val buffer = Buffer()
165+
val decoded = OneOfMsgInternal.CODEC.decode(buffer)
166+
assertNull(decoded.field)
167+
}
168+
127169
}
Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
message OneOfMsg {
2-
3-
oneof f {
4-
int32 i = 2;
5-
fixed64 T = 3;
2+
oneof field {
3+
int32 sint = 2;
4+
fixed64 fixed = 3;
65
}
7-
86
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt

Lines changed: 63 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ class ModelToKotlinCommonGenerator(
7878

7979
generateInternalDeclaredEntities(this@generateInternalKotlinFile)
8080

81+
import("kotlinx.rpc.internal.utils.*")
82+
import("kotlinx.coroutines.flow.*")
83+
8184
additionalInternalImports.forEach {
8285
import(it)
8386
}
@@ -279,12 +282,15 @@ class ModelToKotlinCommonGenerator(
279282
code("return msg")
280283
}
281284

282-
private fun CodeGenerator.readMatchCase(field: FieldDeclaration) {
283-
val encFuncName = field.type.decodeEncodeFuncName()
284-
val assignment = "msg.${field.name} ="
285+
private fun CodeGenerator.readMatchCase(
286+
field: FieldDeclaration,
287+
assignment: String = "msg.${field.name} =",
288+
wrapperCtor: (String) -> String = { it }
289+
) {
285290
when (val fieldType = field.type) {
286291
is FieldType.IntegralType -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.${field.type.wireType.name}") {
287-
code("$assignment decoder.read$encFuncName()")
292+
val raw = "decoder.read${field.type.decodeEncodeFuncName()}()"
293+
code("$assignment ${wrapperCtor(raw)}")
288294
}
289295

290296
is FieldType.List -> if (field.dec.isPacked) {
@@ -299,11 +305,22 @@ class ModelToKotlinCommonGenerator(
299305

300306
is FieldType.Enum -> whenCase("tag.fieldNr == ${field.number} && tag.wireType == $PB_PKG.WireType.VARINT") {
301307
val fromNum = "${fieldType.dec.name.safeFullName()}.fromNumber"
302-
code("$assignment $fromNum(decoder.read$encFuncName())")
308+
val raw = "$fromNum(decoder.read${field.type.decodeEncodeFuncName()}())"
309+
code("$assignment ${wrapperCtor(raw)}")
310+
}
311+
312+
is FieldType.OneOf -> {
313+
fieldType.dec.variants.forEach { variant ->
314+
val variantName = "${fieldType.dec.name.safeFullName()}.${variant.name}"
315+
readMatchCase(
316+
field = variant,
317+
assignment = assignment,
318+
wrapperCtor = { "$variantName($it)" }
319+
)
320+
}
303321
}
304322

305323
is FieldType.Map -> TODO()
306-
is FieldType.OneOf -> TODO()
307324
is FieldType.Reference -> TODO()
308325
}
309326
}
@@ -323,42 +340,54 @@ class ModelToKotlinCommonGenerator(
323340
val fieldName = field.name
324341
if (field.nullable) {
325342
scope("$fieldName?.also") {
326-
code(field.writeValue("it"))
343+
writeFieldValue(field, "it")
327344
}
328345
} else if (!field.dec.hasPresence()) {
329346
ifBranch(condition = field.defaultCheck(), ifBlock = {
330-
code(field.writeValue(field.name))
347+
writeFieldValue(field, field.name)
331348
})
332349
} else {
333-
code(field.writeValue(field.name))
350+
writeFieldValue(field, field.name)
334351
}
335352
}
336353
}
337354

338-
private fun FieldDeclaration.writeValue(variable: String): String {
339-
return when (val fieldType = type) {
340-
is FieldType.IntegralType -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)"
341-
is FieldType.List -> when {
342-
dec.isPacked && packedFixedSize ->
343-
"encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable)"
355+
private fun CodeGenerator.writeFieldValue(field: FieldDeclaration, valueVar: String) {
356+
var encFunc = field.type.decodeEncodeFuncName()
357+
val number = field.number
358+
when (val fieldType = field.type) {
359+
is FieldType.IntegralType -> code("encoder.write${encFunc!!}(fieldNr = $number, value = $valueVar)")
360+
is FieldType.List -> {
361+
encFunc = fieldType.value.decodeEncodeFuncName()
362+
when {
363+
field.dec.isPacked && field.packedFixedSize ->
364+
code("encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar)")
344365

345-
dec.isPacked && !packedFixedSize ->
346-
"encoder.writePacked${fieldType.value.decodeEncodeFuncName()}(fieldNr = $number, value = $variable, fieldSize = ${
347-
wireSizeCall(
348-
variable
366+
field.dec.isPacked && !field.packedFixedSize ->
367+
code(
368+
"encoder.writePacked${encFunc!!}(fieldNr = $number, value = $valueVar, fieldSize = ${
369+
field.wireSizeCall(valueVar)
370+
})"
349371
)
350-
})"
351372

352-
else ->
353-
"$variable.forEach { encoder.write${fieldType.value.decodeEncodeFuncName()}($number, it) }"
373+
else -> code("$valueVar.forEach { encoder.write${encFunc!!}($number, it) }")
374+
}
354375
}
355376

356-
is FieldType.Enum -> "encoder.write${type.decodeEncodeFuncName()}(fieldNr = $number, value = $variable.number)"
377+
is FieldType.Enum -> code("encoder.write${encFunc!!}(fieldNr = $number, value = ${valueVar}.number)")
378+
379+
is FieldType.OneOf -> whenBlock("val value = $valueVar") {
380+
fieldType.dec.variants.forEach { variant ->
381+
whenCase("is ${fieldType.dec.name.safeFullName()}.${variant.name}") {
382+
writeFieldValue(variant, "value.value")
383+
}
384+
}
385+
}
357386

358387
is FieldType.Map -> TODO()
359-
is FieldType.OneOf -> TODO()
360-
is FieldType.Reference -> "<TODO: Implement Reference writeValue()>"
388+
is FieldType.Reference -> code("<TODO: Implement Reference writeValue()>")
361389
}
390+
362391
}
363392

364393

@@ -370,9 +399,9 @@ class ModelToKotlinCommonGenerator(
370399
contextReceiver = "${enum.name.safeFullName()}.Companion",
371400
returnType = enum.name.safeFullName(),
372401
) {
373-
whenBlock(prefix = "return") {
402+
whenBlock(prefix = "return", condition = "number") {
374403
enum.originalEntries.forEach { entry ->
375-
whenCase("number == ${entry.dec.number}") {
404+
whenCase("${entry.dec.number}") {
376405
code("${entry.name}")
377406
}
378407
}
@@ -408,7 +437,8 @@ class ModelToKotlinCommonGenerator(
408437

409438

410439
private fun FieldDeclaration.wireSizeCall(variable: String): String {
411-
val sizeFunc = "$PB_PKG.WireSize.${type.decodeEncodeFuncName().replaceFirstChar { it.lowercase() }}($variable)"
440+
val sizeFunc =
441+
"$PB_PKG.WireSize.${type.decodeEncodeFuncName()!!.replaceFirstChar { it.lowercase() }}($variable)"
412442
return when (val fieldType = type) {
413443
is FieldType.IntegralType -> when {
414444
fieldType.wireType == WireType.FIXED32 -> "32"
@@ -444,7 +474,7 @@ class ModelToKotlinCommonGenerator(
444474
}
445475
}
446476

447-
private fun FieldType.decodeEncodeFuncName(): String = when (this) {
477+
private fun FieldType.decodeEncodeFuncName(): String? = when (this) {
448478
FieldType.IntegralType.STRING -> "String"
449479
FieldType.IntegralType.BYTES -> "Bytes"
450480
FieldType.IntegralType.BOOL -> "Bool"
@@ -462,9 +492,9 @@ class ModelToKotlinCommonGenerator(
462492
FieldType.IntegralType.SFIXED64 -> "SFixed64"
463493
is FieldType.List -> "Packed${value.decodeEncodeFuncName()}"
464494
is FieldType.Enum -> "Enum"
465-
is FieldType.Map -> error("No encoding/decoding function for map types")
466-
is FieldType.OneOf -> error("No encoding/decoding function for oneOf types")
467-
is FieldType.Reference -> error("No encoding/decoding function for sub message types")
495+
is FieldType.Map -> null
496+
is FieldType.OneOf -> null
497+
is FieldType.Reference -> null
468498
}
469499

470500
private fun FieldDeclaration.transformToFieldDeclaration(): String {
@@ -480,10 +510,7 @@ class ModelToKotlinCommonGenerator(
480510

481511
is FieldType.Enum -> type.dec.name.safeFullName()
482512

483-
is FieldType.OneOf -> {
484-
val value by type.value
485-
value.safeFullName()
486-
}
513+
is FieldType.OneOf -> type.dec.name.safeFullName()
487514

488515
is FieldType.IntegralType -> {
489516
type.fqName.simpleName

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/codeRequestToModel.kt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,20 +113,31 @@ private fun Descriptors.FileDescriptor.toModel(): FileDeclaration = cached {
113113

114114
private fun Descriptors.Descriptor.toModel(): MessageDeclaration = cached {
115115
var currPresenceIdx = 0
116-
val regularFields = fields
116+
var regularFields = fields
117117
// only fields that are not part of a oneOf declaration
118118
.filter { field -> field.realContainingOneof == null }
119119
.map {
120120
val presenceIdx = if (it.hasPresence()) currPresenceIdx++ else null
121121
it.toModel(presenceIdx = presenceIdx)
122122
}
123+
val oneOfs = oneofs.filter { it.fields[0].realContainingOneof != null }.map { it.toModel() }
124+
125+
regularFields = regularFields + oneOfs.map {
126+
FieldDeclaration(
127+
// TODO: Proper handling of this field name
128+
it.name.simpleName.lowercase(),
129+
FieldType.OneOf(it),
130+
doc = null,
131+
dec = it.variants.first().dec,
132+
)
133+
}
123134

124135
return MessageDeclaration(
125136
name = fqName(),
126137
presenceMaskSize = currPresenceIdx,
127138
actualFields = regularFields,
128139
// get all oneof declarations that are not created from an optional in proto3 https://github.com/googleapis/api-linter/issues/1323
129-
oneOfDeclarations = oneofs.filter { it.fields[0].realContainingOneof != null }.map { it.toModel() },
140+
oneOfDeclarations = oneOfs,
130141
enumDeclarations = enumTypes.map { it.toModel() },
131142
nestedDeclarations = nestedTypes.map { it.toModel() },
132143
doc = null,

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldType.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ sealed interface FieldType {
4242
override val wireType: WireType = WireType.LENGTH_DELIMITED
4343
}
4444

45-
data class OneOf(val value: Lazy<FqName>, val index: Int) : FieldType {
45+
data class OneOf(val dec: OneOfDeclaration) : FieldType {
4646
override val defaultValue: String = "null"
4747
override val wireType: WireType = WireType.LENGTH_DELIMITED
4848
}

protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/model/model.kt

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,11 @@ data class FieldDeclaration(
7676
val packedFixedSize = type.wireType == WireType.FIXED64 || type.wireType == WireType.FIXED32
7777

7878
// aligns with edition settings and backward compatibility with proto2 and proto3
79-
val nullable: Boolean = dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue() && !dec.isRepeated
79+
val nullable: Boolean = (dec.hasPresence() && !dec.isRequired && !dec.hasDefaultValue()
80+
&& !dec.isRepeated // repeated fields cannot be nullable (just empty)
81+
&& dec.realContainingOneof == null // upper conditions would match oneof inner fields
82+
)
83+
|| type is FieldType.OneOf // all OneOf fields are nullable
8084
val number: Int = dec.number
8185
}
8286

0 commit comments

Comments
 (0)