Skip to content

Commit 6eab5d9

Browse files
committed
Support streaming (#389)
1 parent 86c1a26 commit 6eab5d9

File tree

5 files changed

+188
-26
lines changed

5 files changed

+188
-26
lines changed

docs/pages/kotlinx-rpc/topics/platforms.topic

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,14 @@
7777
<td><list><li>apple<list><li>ios<list><li>iosArm64</li><li>iosSimulatorArm64</li><li>iosX64</li></list></li><li>macos<list><li>macosArm64</li><li>macosX64</li></list></li><li>watchos<list><li>watchosArm32</li><li>watchosArm64</li><li>watchosDeviceArm64</li><li>watchosSimulatorArm64</li><li>watchosX64</li></list></li><li>tvos<list><li>tvosArm64</li><li>tvosSimulatorArm64</li><li>tvosX64</li></list></li></list></li><li>linux<list><li>linuxArm64</li><li>linuxX64</li></list></li><li>windows<list><li>mingwX64</li></list></li></list></td>
7878
</tr>
7979

80+
<tr>
81+
<td>protobuf-plugin</td>
82+
<td>Jvm Only</td>
83+
<td>-</td>
84+
<td>-</td>
85+
<td>-</td>
86+
</tr>
87+
8088
<tr>
8189
<td>utils</td>
8290
<td>jvm</td>
@@ -85,6 +93,14 @@
8593
<td><list><li>apple<list><li>ios<list><li>iosArm64</li><li>iosSimulatorArm64</li><li>iosX64</li></list></li><li>macos<list><li>macosArm64</li><li>macosX64</li></list></li><li>watchos<list><li>watchosArm32</li><li>watchosArm64</li><li>watchosDeviceArm64</li><li>watchosSimulatorArm64</li><li>watchosX64</li></list></li><li>tvos<list><li>tvosArm64</li><li>tvosSimulatorArm64</li><li>tvosX64</li></list></li></list></li><li>linux<list><li>linuxArm64</li><li>linuxX64</li></list></li><li>windows<list><li>mingwX64</li></list></li></list></td>
8694
</tr>
8795

96+
<tr>
97+
<td>grpc-core</td>
98+
<td>jvm</td>
99+
<td><list><li>browser</li><li>node</li></list></td>
100+
<td><list><li>wasmJs<list><li>browser</li><li>d8</li><li>node</li></list></li></list></td>
101+
<td><list><li>apple<list><li>ios<list><li>iosArm64</li><li>iosSimulatorArm64</li><li>iosX64</li></list></li><li>macos<list><li>macosArm64</li><li>macosX64</li></list></li><li>watchos<list><li>watchosArm32</li><li>watchosArm64</li><li>watchosDeviceArm64</li><li>watchosSimulatorArm64</li><li>watchosX64</li></list></li><li>tvos<list><li>tvosArm64</li><li>tvosSimulatorArm64</li><li>tvosX64</li></list></li></list></li><li>linux<list><li>linuxArm64</li><li>linuxX64</li></list></li><li>windows<list><li>mingwX64</li></list></li></list></td>
102+
</tr>
103+
88104
<tr>
89105
<td>krpc-client</td>
90106
<td>jvm</td>

protobuf-plugin/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinGenerator.kt

Lines changed: 78 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package kotlinx.rpc.protobuf
99
import kotlinx.rpc.protobuf.CodeGenerator.DeclarationType
1010
import kotlinx.rpc.protobuf.model.*
1111
import org.slf4j.Logger
12+
import kotlin.getValue
1213

1314
private const val RPC_INTERNAL_PACKAGE_SUFFIX = "_rpc_internal"
1415

@@ -50,6 +51,7 @@ class ModelToKotlinGenerator(
5051
generatePublicDeclaredEntities(this@generatePublicKotlinFile)
5152

5253
import("kotlinx.rpc.internal.utils.*")
54+
import("kotlinx.coroutines.flow.*")
5355

5456
additionalPublicImports.forEach {
5557
import(it)
@@ -76,6 +78,7 @@ class ModelToKotlinGenerator(
7678
generateInternalDeclaredEntities(this@generateInternalKotlinFile)
7779

7880
import("kotlinx.rpc.internal.utils.*")
81+
import("kotlinx.coroutines.flow.*")
7982

8083
additionalInternalImports.forEach {
8184
import(it)
@@ -510,19 +513,22 @@ class ModelToKotlinGenerator(
510513
code("@kotlinx.rpc.grpc.annotations.Grpc")
511514
clazz(service.name.simpleName, declarationType = DeclarationType.Interface) {
512515
service.methods.forEach { method ->
513-
// no streaming for now
514516
val inputType by method.inputType
515517
val outputType by method.outputType
516518
function(
517519
name = method.name,
518-
modifiers = "suspend",
519-
args = "message: ${inputType.name.safeFullName()}",
520-
returnType = outputType.name.safeFullName(),
520+
modifiers = if (method.serverStreaming) "" else "suspend",
521+
args = "message: ${inputType.name.safeFullName().wrapInFlowIf(method.clientStreaming)}",
522+
returnType = outputType.name.safeFullName().wrapInFlowIf(method.serverStreaming),
521523
)
522524
}
523525
}
524526
}
525527

528+
private fun String.wrapInFlowIf(condition: Boolean): String {
529+
return if (condition) "Flow<$this>" else this
530+
}
531+
526532
private fun CodeGenerator.generateInternalService(service: ServiceDeclaration) {
527533
code("@Suppress(\"unused\", \"all\")")
528534
clazz(
@@ -566,11 +572,23 @@ class ModelToKotlinGenerator(
566572

567573
function(
568574
name = grpcName,
569-
modifiers = "override suspend",
570-
args = "request: ${inputType.toPlatformMessageType()}",
571-
returnType = outputType.toPlatformMessageType(),
575+
modifiers = "override${if (method.serverStreaming) "" else " suspend"}",
576+
args = "request: ${inputType.toPlatformMessageType().wrapInFlowIf(method.clientStreaming)}",
577+
returnType = outputType.toPlatformMessageType().wrapInFlowIf(method.serverStreaming),
572578
) {
573-
code("return impl.${method.name}(request.toKotlin()).toPlatform()")
579+
val toKotlin = if (method.clientStreaming) {
580+
"map { it.toKotlin() }"
581+
} else {
582+
"toKotlin()"
583+
}
584+
585+
val toPlatform = if (method.serverStreaming) {
586+
"map { it.toPlatform() }"
587+
} else {
588+
"toPlatform()"
589+
}
590+
591+
code("return impl.${method.name}(request.${toKotlin}).${toPlatform}")
574592

575593
importRootDeclarationIfNeeded(inputType.name, "toPlatform", true)
576594
importRootDeclarationIfNeeded(outputType.name, "toKotlin", true)
@@ -605,36 +623,72 @@ class ModelToKotlinGenerator(
605623
typeParameters = "R",
606624
returnType = "R",
607625
) {
608-
code("val message = rpcCall.parameters[0]")
609-
code("@Suppress(\"UNCHECKED_CAST\")")
610-
scope("return when (rpcCall.callableName)") {
611-
service.methods.forEach { method ->
612-
val inputType by method.inputType
613-
val outputType by method.outputType
614-
val grpcName = method.name.replaceFirstChar { it.lowercase() }
615-
val result = "stub.$grpcName((message as ${inputType.name.safeFullName()}).toPlatform())"
616-
code("\"${method.name}\" -> $result.toKotlin() as R")
617-
618-
importRootDeclarationIfNeeded(inputType.name, "toPlatform", true)
619-
importRootDeclarationIfNeeded(outputType.name, "toKotlin", true)
620-
}
626+
val methods = service.methods.filter { !it.serverStreaming }
621627

622-
code("else -> error(\"Illegal call: \${rpcCall.callableName}\")")
628+
if (methods.isEmpty()) {
629+
code("error(\"Illegal call: \${rpcCall.callableName}\")")
630+
return@function
623631
}
632+
633+
generateCallsImpls(methods)
624634
}
625635

626636
function(
627637
name = "callServerStreaming",
628638
modifiers = "override",
629639
args = "rpcCall: kotlinx.rpc.RpcCall",
630640
typeParameters = "R",
631-
returnType = "kotlinx.coroutines.flow.Flow<R>",
641+
returnType = "Flow<R>",
632642
) {
633-
code("error(\"Flow calls are not supported\")")
643+
val methods = service.methods.filter { it.serverStreaming }
644+
645+
if (methods.isEmpty()) {
646+
code("error(\"Illegal streaming call: \${rpcCall.callableName}\")")
647+
return@function
648+
}
649+
650+
generateCallsImpls(methods)
634651
}
635652
}
636653
}
637654

655+
private fun CodeGenerator.generateCallsImpls(
656+
methods: List<MethodDeclaration>,
657+
) {
658+
code("val message = rpcCall.parameters[0]")
659+
code("@Suppress(\"UNCHECKED_CAST\")")
660+
scope("return when (rpcCall.callableName)") {
661+
methods.forEach { method ->
662+
val inputType by method.inputType
663+
val outputType by method.outputType
664+
val grpcName = method.name.replaceFirstChar { it.lowercase() }
665+
666+
val toKotlin = if (method.serverStreaming) {
667+
"map { it.toKotlin() }"
668+
} else {
669+
"toKotlin()"
670+
}
671+
672+
val toPlatform = if (method.clientStreaming) {
673+
"map { it.toPlatform() }"
674+
} else {
675+
"toPlatform()"
676+
}
677+
678+
val argumentCast = inputType.name.safeFullName().wrapInFlowIf(method.clientStreaming)
679+
val resultCast = "R".wrapInFlowIf(method.serverStreaming)
680+
681+
val result = "stub.$grpcName((message as $argumentCast).${toPlatform})"
682+
code("\"${method.name}\" -> $result.${toKotlin} as $resultCast")
683+
684+
importRootDeclarationIfNeeded(inputType.name, "toPlatform", true)
685+
importRootDeclarationIfNeeded(outputType.name, "toKotlin", true)
686+
}
687+
688+
code("else -> error(\"Illegal call: \${rpcCall.callableName}\")")
689+
}
690+
}
691+
638692
private fun MessageDeclaration.toPlatformMessageType(): String {
639693
return "${outerClassName.safeFullName()}.${name.fullNestedName()}"
640694
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.protobuf.test
6+
7+
import StreamingTestService
8+
import kotlinx.coroutines.flow.Flow
9+
import kotlinx.coroutines.flow.collectIndexed
10+
import kotlinx.coroutines.flow.flow
11+
import kotlinx.coroutines.flow.last
12+
import kotlinx.coroutines.flow.toList
13+
import kotlinx.rpc.RpcServer
14+
import kotlinx.rpc.registerService
15+
import kotlinx.rpc.withService
16+
import kotlin.test.Test
17+
import kotlin.test.assertEquals
18+
19+
class StreamingTestServiceImpl : StreamingTestService {
20+
override fun Server(message: References): Flow<References> {
21+
return flow { emit(message); emit(message); emit(message) }
22+
}
23+
24+
override suspend fun Client(message: Flow<References>): References {
25+
return message.last()
26+
}
27+
28+
override fun Bidi(message: Flow<References>): Flow<References> {
29+
return message
30+
}
31+
}
32+
33+
class StreamingTest : GrpcServerTest() {
34+
override fun RpcServer.registerServices() {
35+
registerService<StreamingTestService> { StreamingTestServiceImpl() }
36+
}
37+
38+
@Test
39+
fun testServerStreaming() = runGrpcTest { grpcClient ->
40+
val service = grpcClient.withService<StreamingTestService>()
41+
service.Server(References {
42+
other = Other {
43+
field= 42
44+
}
45+
}).toList().run {
46+
assertEquals(3, size)
47+
48+
forEach {
49+
assertEquals(42, it.other.field)
50+
}
51+
}
52+
}
53+
54+
@Test
55+
fun testClientStreaming() = runGrpcTest { grpcClient ->
56+
val service = grpcClient.withService<StreamingTestService>()
57+
val result = service.Client(flow {
58+
repeat(3) {
59+
emit(References {
60+
other = Other {
61+
field = 42 + it
62+
}
63+
})
64+
}
65+
})
66+
67+
assertEquals(44, result.other.field)
68+
}
69+
70+
@Test
71+
fun testBidiStreaming() = runGrpcTest { grpcClient ->
72+
val service = grpcClient.withService<StreamingTestService>()
73+
service.Bidi(flow {
74+
repeat(3) {
75+
emit(References {
76+
other = Other {
77+
field = 42 + it
78+
}
79+
})
80+
}
81+
}).collectIndexed { i, it ->
82+
assertEquals(42 + i, it.other.field)
83+
}
84+
}
85+
}

protobuf-plugin/src/test/proto/reference.proto

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
syntax = "proto3";
22

3-
import "all_primitives.proto";
4-
53
message Other {
64
string arg = 1;
75
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
syntax = "proto3";
2+
3+
import "reference_package.proto";
4+
5+
service StreamingTestService {
6+
rpc Server(kotlinx.rpc.protobuf.test.References) returns (stream kotlinx.rpc.protobuf.test.References);
7+
rpc Client(stream kotlinx.rpc.protobuf.test.References) returns (kotlinx.rpc.protobuf.test.References);
8+
rpc Bidi(stream kotlinx.rpc.protobuf.test.References) returns (stream kotlinx.rpc.protobuf.test.References);
9+
}

0 commit comments

Comments
 (0)