Skip to content

Commit 2954560

Browse files
committed
KRPC-147 OneOf types in gRPC (#332)
1 parent e3020cd commit 2954560

File tree

9 files changed

+166
-79
lines changed

9 files changed

+166
-79
lines changed

protobuf-plugin/build.gradle.kts

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ sourceSets {
3737
"**/enum_options.proto",
3838
"**/empty_deprecated.proto",
3939
"**/example.proto",
40-
"**/funny_types.proto",
4140
"**/multiple_files.proto",
42-
"**/one_of.proto",
4341
"**/options.proto",
4442
"**/with_comments.proto",
4543
)

protobuf-plugin/src/main/kotlin/kotlinx/rpc/protobuf/CodeGenerator.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,11 @@ open class CodeGenerator(
7575
prefix: String,
7676
suffix: String = "",
7777
nlAfterClosed: Boolean = true,
78+
openingBracket: Boolean = true,
7879
block: (CodeGenerator.() -> Unit)? = null,
7980
) {
8081
addLine(prefix)
81-
scopeWithSuffix(suffix, nlAfterClosed, block)
82+
scopeWithSuffix(suffix, openingBracket, nlAfterClosed, block)
8283
}
8384

8485
internal fun ifBranch(
@@ -93,6 +94,7 @@ open class CodeGenerator(
9394

9495
private fun scopeWithSuffix(
9596
suffix: String = "",
97+
openingBracket: Boolean = true,
9698
nlAfterClosed: Boolean = true,
9799
block: (CodeGenerator.() -> Unit)? = null,
98100
) {
@@ -110,7 +112,9 @@ open class CodeGenerator(
110112
return
111113
}
112114

113-
append(" {")
115+
if (openingBracket) {
116+
append(" {")
117+
}
114118
newLine()
115119
append(nested.build().trimEnd())
116120
addLine("}$suffix")

protobuf-plugin/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinGenerator.kt

Lines changed: 82 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,10 @@ class ModelToKotlinGenerator(
132132

133133
newLine()
134134

135-
// KRPC-147 OneOf Types
136-
// declaration.oneOfDeclarations.forEach { oneOf ->
137-
// generateOneOf(oneOf)
138-
// }
139-
//
135+
declaration.oneOfDeclarations.forEach { oneOf ->
136+
generateOneOfPublic(oneOf)
137+
}
138+
140139
declaration.nestedDeclarations.forEach { nested ->
141140
generatePublicMessage(nested)
142141
}
@@ -201,17 +200,29 @@ class ModelToKotlinGenerator(
201200
) {
202201
scope("return $platformType.newBuilder().apply", ".build()") {
203202
declaration.actualFields.forEach { field ->
204-
val uppercaseName = field.name.replaceFirstChar { ch -> ch.uppercase() }
205-
val setFieldCall = when (field.type) {
206-
is FieldType.List -> "addAll$uppercaseName"
207-
is FieldType.Map -> "putAll$uppercaseName"
208-
else -> "set$uppercaseName"
209-
}
203+
val setFieldCall = setFieldCall(field)
204+
205+
when {
206+
field.type is FieldType.OneOf -> {
207+
scope("this@toPlatform.${field.name}?.let { value ->", openingBracket = false) {
208+
scope("when (value)") {
209+
val oneOf = declaration.oneOfDeclarations[field.type.index]
210+
val oneOfName = oneOf.name.safeFullName()
211+
212+
oneOf.variants.forEach { variant ->
213+
code("is $oneOfName.${variant.name} -> ${setFieldCall(variant)}(value.value${variant.type.toPlatformCast()})")
214+
}
215+
}
216+
}
217+
}
218+
219+
field.nullable -> {
220+
code("this@toPlatform.${field.name}?.let { $setFieldCall(it${field.type.toPlatformCast()}) }")
221+
}
210222

211-
if (field.nullable) {
212-
code("this@toPlatform.${field.name}?.let { $setFieldCall(it${field.type.toPlatformCast()}) }")
213-
} else {
214-
code("$setFieldCall(this@toPlatform.${field.name}${field.type.toPlatformCast()})")
223+
else -> {
224+
code("$setFieldCall(this@toPlatform.${field.name}${field.type.toPlatformCast()})")
225+
}
215226
}
216227
}
217228
}
@@ -224,32 +235,64 @@ class ModelToKotlinGenerator(
224235
) {
225236
scope("return ${declaration.name.safeFullName()}") {
226237
declaration.actualFields.forEach { field ->
227-
val javaName = when (field.type) {
228-
is FieldType.List -> "${field.name}List"
229-
is FieldType.Map -> "${field.name}Map"
230-
else -> field.name
231-
}
238+
val javaName = fieldJavaName(field)
232239

233240
val getter = "this@toKotlin.$javaName${field.type.toKotlinCast()}"
234-
if (field.nullable) {
235-
ifBranch(
236-
prefix = "${field.name} = ",
237-
condition = "has${field.name.replaceFirstChar { ch -> ch.uppercase() }}()",
238-
ifBlock = {
239-
code(getter)
240-
},
241-
elseBlock = {
242-
code("null")
241+
when {
242+
field.type is FieldType.OneOf -> {
243+
val oneOf = declaration.oneOfDeclarations[field.type.index]
244+
val oneOfName = oneOf.name.safeFullName()
245+
246+
scope("${field.name} = when") {
247+
oneOf.variants.forEach { variant ->
248+
code("${hasFieldJavaMethod(variant)} -> $oneOfName.${variant.name}(this@toKotlin.${fieldJavaName(variant)}${variant.type.toKotlinCast()})")
249+
}
250+
code("else -> null")
243251
}
244-
)
245-
} else {
246-
code("${field.name} = $getter")
252+
}
253+
254+
field.nullable -> {
255+
ifBranch(
256+
prefix = "${field.name} = ",
257+
condition = hasFieldJavaMethod(field),
258+
ifBlock = {
259+
code(getter)
260+
},
261+
elseBlock = {
262+
code("null")
263+
}
264+
)
265+
}
266+
else -> {
267+
code("${field.name} = $getter")
268+
}
247269
}
248270
}
249271
}
250272
}
251273
}
252274

275+
private fun fieldJavaName(field: FieldDeclaration): String {
276+
val name = field.name.replaceFirstChar { ch -> ch.lowercase() }
277+
return when (field.type) {
278+
is FieldType.List -> "${name}List"
279+
is FieldType.Map -> "${name}Map"
280+
else -> name
281+
}
282+
}
283+
284+
private fun hasFieldJavaMethod(field: FieldDeclaration): String = "has${field.name.replaceFirstChar { ch -> ch.uppercase() }}()"
285+
286+
private fun setFieldCall(field: FieldDeclaration): String {
287+
val uppercaseName = field.name.replaceFirstChar { ch -> ch.uppercase() }
288+
val setFieldCall = when (field.type) {
289+
is FieldType.List -> "addAll$uppercaseName"
290+
is FieldType.Map -> "putAll$uppercaseName"
291+
else -> "set$uppercaseName"
292+
}
293+
return setFieldCall
294+
}
295+
253296
private fun FieldType.toPlatformCast(): String {
254297
return when (this) {
255298
FieldType.IntegralType.FIXED32 -> ".toInt()"
@@ -339,6 +382,11 @@ class ModelToKotlinGenerator(
339382
value.safeFullName()
340383
}
341384

385+
is FieldType.OneOf -> {
386+
val value by type.value
387+
value.safeFullName()
388+
}
389+
342390
is FieldType.IntegralType -> {
343391
type.fqName.simpleName
344392
}
@@ -377,11 +425,10 @@ class ModelToKotlinGenerator(
377425
return "$this${if (nullable) "?" else ""}"
378426
}
379427

380-
@Suppress("unused")
381-
private fun CodeGenerator.generateOneOf(declaration: OneOfDeclaration) {
428+
private fun CodeGenerator.generateOneOfPublic(declaration: OneOfDeclaration) {
382429
val interfaceName = declaration.name.simpleName
383430

384-
clazz(declaration.name.simpleName, "sealed", declarationType = DeclarationType.Interface) {
431+
clazz(interfaceName, "sealed", declarationType = DeclarationType.Interface) {
385432
declaration.variants.forEach { variant ->
386433
clazz(
387434
name = variant.name,
@@ -420,7 +467,6 @@ class ModelToKotlinGenerator(
420467
}
421468
}
422469

423-
@Suppress("unused")
424470
private fun CodeGenerator.generateToAndFromPlatformCastsEnum(declaration: EnumDeclaration) {
425471
val platformType = "${declaration.outerClassName.safeFullName()}.${declaration.name.fullNestedName()}"
426472

protobuf-plugin/src/main/kotlin/kotlinx/rpc/protobuf/ProtoToModelInterpreter.kt

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class ProtoToModelInterpreter(
120120
outerClassName = outerClass,
121121
name = fqName,
122122
actualFields = fields,
123-
oneOfDeclarations = oneofDeclList.mapIndexedNotNull { i, desc -> desc.toModel(i, resolver) },
123+
oneOfDeclarations = oneofDeclList.mapIndexedNotNull { i, desc -> desc.toModel(i, resolver, fqName) },
124124
enumDeclarations = enumTypeList.map { it.toModel(resolver, outerClass, fqName) },
125125
nestedDeclarations = nestedDeclarations,
126126
deprecated = options.deprecated,
@@ -130,12 +130,13 @@ class ProtoToModelInterpreter(
130130
}
131131
}
132132

133-
private fun DescriptorProtos.DescriptorProto.resolveMapEntry(resolver: NameResolver): Lazy<FieldType.Map.Entry> = lazy {
134-
val keyType = fieldList[0].toModel(null, resolver) ?: error("Key type is null")
135-
val valueType = fieldList[1].toModel(null, resolver) ?: error("Value type is null")
133+
private fun DescriptorProtos.DescriptorProto.resolveMapEntry(resolver: NameResolver): Lazy<FieldType.Map.Entry> =
134+
lazy {
135+
val keyType = fieldList[0].toModel(null, resolver) ?: error("Key type is null")
136+
val valueType = fieldList[1].toModel(null, resolver) ?: error("Value type is null")
136137

137-
FieldType.Map.Entry(keyType.type, valueType.type)
138-
}
138+
FieldType.Map.Entry(keyType.type, valueType.type)
139+
}
139140

140141
private val oneOfFieldMembers = mutableMapOf<Int, MutableList<DescriptorProtos.FieldDescriptorProto>>()
141142

@@ -155,8 +156,8 @@ class ProtoToModelInterpreter(
155156
oneOfFieldMembers[oneofIndex] = mutableListOf<DescriptorProtos.FieldDescriptorProto>()
156157
.also { list -> list.add(this) }
157158

158-
TODO("KRPC-147 OneOf Types")
159-
// FieldType.Reference(oneOfName.fullProtoNameToKotlin(firstLetterUpper = true).toFqName())
159+
val name = oneOfName.fullProtoNameToKotlin(firstLetterUpper = true)
160+
FieldType.OneOf(lazy { resolver.resolve(name) }, oneofIndex)
160161
}
161162

162163
else -> {
@@ -168,7 +169,6 @@ class ProtoToModelInterpreter(
168169
return FieldDeclaration(
169170
name = oneOfName.removePrefix("_").fullProtoNameToKotlin(),
170171
type = fieldType,
171-
// TODO KRPC-147 OneOf Types: check nullability
172172
nullable = true,
173173
deprecated = options.deprecated,
174174
doc = null,
@@ -245,9 +245,6 @@ class ProtoToModelInterpreter(
245245
return wrapWithLabel(fieldType)
246246
}
247247

248-
private fun String.asReference(resolver: (String) -> FqName) =
249-
FieldType.Reference(lazy { resolver(this) })
250-
251248
private fun DescriptorProtos.FieldDescriptorProto.wrapWithLabel(fieldType: FieldType): FieldType {
252249
return when (label) {
253250
DescriptorProtos.FieldDescriptorProto.Label.LABEL_REPEATED -> {
@@ -261,19 +258,23 @@ class ProtoToModelInterpreter(
261258
}
262259
}
263260

264-
private fun DescriptorProtos.OneofDescriptorProto.toModel(index: Int, resolver: NameResolver): OneOfDeclaration? {
265-
// TODO KRPC-146 Nested Types: parent full type resolution
266-
// KRPC-147 OneOf Types: check fqName
261+
private fun DescriptorProtos.OneofDescriptorProto.toModel(
262+
index: Int,
263+
resolver: NameResolver,
264+
parent: FqName,
265+
): OneOfDeclaration? {
267266
val name = name.fullProtoNameToKotlin(firstLetterUpper = true)
268-
val fqName = resolver.declarationFqName(name, packageName)
267+
val fqName = resolver.declarationFqName(name, parent)
269268

270269
val fields = oneOfFieldMembers[index] ?: return null
270+
271+
val fieldResolver = resolver.withScope(fqName)
271272
return OneOfDeclaration(
272273
name = fqName,
273274
variants = fields.map { field ->
274275
FieldDeclaration(
275276
name = field.name.fullProtoNameToKotlin(firstLetterUpper = true),
276-
type = field.fieldType(resolver),
277+
type = field.fieldType(fieldResolver),
277278
nullable = false,
278279
deprecated = field.options.deprecated,
279280
doc = null,

protobuf-plugin/src/main/kotlin/kotlinx/rpc/protobuf/model/FieldDeclaration.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ sealed interface FieldType {
2929
override val defaultValue: String = "null"
3030
}
3131

32+
data class OneOf(val value: Lazy<FqName>, val index: Int) : FieldType {
33+
override val defaultValue: String = "null"
34+
}
35+
3236
enum class IntegralType(simpleName: String, override val defaultValue: String) : FieldType {
3337
STRING("String", "\"\""),
3438
BYTES("ByteArray", "byteArrayOf()"),

protobuf-plugin/src/test/kotlin/kotlinx/rpc/protobuf/test/TestReferenceService.kt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import kotlinx.rpc.RpcServer
1212
import kotlinx.rpc.registerService
1313
import kotlinx.rpc.withService
1414
import kotlin.test.Test
15+
import kotlin.test.assertContentEquals
1516
import kotlin.test.assertEquals
1617
import kotlin.test.assertNotNull
1718

@@ -45,6 +46,10 @@ class ReferenceTestServiceImpl : ReferenceTestService {
4546
override suspend fun Map(message: TestMap): TestMap {
4647
return message
4748
}
49+
50+
override suspend fun OneOf(message: OneOf): OneOf {
51+
return message
52+
}
4853
}
4954

5055
class TestReferenceService : GrpcServerTest() {
@@ -217,4 +222,48 @@ class TestReferenceService : GrpcServerTest() {
217222
assertEquals(mapOf("1" to 2L, "2" to 1L), result.primitives)
218223
assertEquals(mapOf("ref" to 42), result.references.mapValues { it.value.other.field })
219224
}
225+
226+
@Test
227+
fun testOneOf() = runGrpcTest { grpcClient ->
228+
val service = grpcClient.withService<ReferenceTestService>()
229+
val result1 = service.OneOf(OneOf {
230+
primitives = OneOf.Primitives.StringValue("42")
231+
references = OneOf.References.Other(kotlinx.rpc.protobuf.test.Other {
232+
field = 42
233+
})
234+
mixed = OneOf.Mixed.Int64(42L)
235+
single = OneOf.Single.Bytes(byteArrayOf(42))
236+
})
237+
238+
assertEquals("42", (result1.primitives as OneOf.Primitives.StringValue).value)
239+
assertEquals(42, (result1.references as OneOf.References.Other).value.field)
240+
assertEquals(42L, (result1.mixed as OneOf.Mixed.Int64).value)
241+
assertContentEquals(byteArrayOf(42), (result1.single as OneOf.Single.Bytes).value)
242+
243+
val result2 = service.OneOf(OneOf {
244+
primitives = OneOf.Primitives.Bool(true)
245+
references = OneOf.References.InnerReferences(kotlinx.rpc.protobuf.test.References {
246+
other = kotlinx.rpc.protobuf.test.Other {
247+
field = 42
248+
}
249+
})
250+
mixed = OneOf.Mixed.AllPrimitives(AllPrimitives {
251+
string = "42"
252+
})
253+
})
254+
255+
assertEquals(true, (result2.primitives as OneOf.Primitives.Bool).value)
256+
assertEquals(42, (result2.references as OneOf.References.InnerReferences).value.other.field)
257+
assertEquals("42", (result2.mixed as OneOf.Mixed.AllPrimitives).value.string)
258+
assertEquals(null, result2.single)
259+
260+
val result3 = service.OneOf(OneOf {
261+
primitives = OneOf.Primitives.Int32(42)
262+
})
263+
264+
assertEquals(42, (result3.primitives as OneOf.Primitives.Int32).value)
265+
assertEquals(null, result3.references)
266+
assertEquals(null, result3.mixed)
267+
assertEquals(null, result3.single)
268+
}
220269
}

protobuf-plugin/src/test/proto/funny_types.proto

Lines changed: 0 additions & 17 deletions
This file was deleted.

0 commit comments

Comments
 (0)