@@ -9,6 +9,7 @@ package kotlinx.rpc.protobuf
9
9
import kotlinx.rpc.protobuf.CodeGenerator.DeclarationType
10
10
import kotlinx.rpc.protobuf.model.*
11
11
import org.slf4j.Logger
12
+ import kotlin.getValue
12
13
13
14
private const val RPC_INTERNAL_PACKAGE_SUFFIX = " _rpc_internal"
14
15
@@ -50,6 +51,7 @@ class ModelToKotlinGenerator(
50
51
generatePublicDeclaredEntities(this @generatePublicKotlinFile)
51
52
52
53
import(" kotlinx.rpc.internal.utils.*" )
54
+ import(" kotlinx.coroutines.flow.*" )
53
55
54
56
additionalPublicImports.forEach {
55
57
import(it)
@@ -76,6 +78,7 @@ class ModelToKotlinGenerator(
76
78
generateInternalDeclaredEntities(this @generateInternalKotlinFile)
77
79
78
80
import(" kotlinx.rpc.internal.utils.*" )
81
+ import(" kotlinx.coroutines.flow.*" )
79
82
80
83
additionalInternalImports.forEach {
81
84
import(it)
@@ -510,19 +513,22 @@ class ModelToKotlinGenerator(
510
513
code(" @kotlinx.rpc.grpc.annotations.Grpc" )
511
514
clazz(service.name.simpleName, declarationType = DeclarationType .Interface ) {
512
515
service.methods.forEach { method ->
513
- // no streaming for now
514
516
val inputType by method.inputType
515
517
val outputType by method.outputType
516
518
function(
517
519
name = method.name,
518
- modifiers = " suspend" ,
519
- args = " message: ${inputType.name.safeFullName()} " ,
520
- returnType = outputType.name.safeFullName(),
520
+ modifiers = if (method.serverStreaming) " " else " suspend" ,
521
+ args = " message: ${inputType.name.safeFullName().wrapInFlowIf(method.clientStreaming) } " ,
522
+ returnType = outputType.name.safeFullName().wrapInFlowIf(method.serverStreaming) ,
521
523
)
522
524
}
523
525
}
524
526
}
525
527
528
+ private fun String.wrapInFlowIf (condition : Boolean ): String {
529
+ return if (condition) " Flow<$this >" else this
530
+ }
531
+
526
532
private fun CodeGenerator.generateInternalService (service : ServiceDeclaration ) {
527
533
code(" @Suppress(\" unused\" , \" all\" )" )
528
534
clazz(
@@ -566,11 +572,23 @@ class ModelToKotlinGenerator(
566
572
567
573
function(
568
574
name = grpcName,
569
- modifiers = " override suspend" ,
570
- args = " request: ${inputType.toPlatformMessageType()} " ,
571
- returnType = outputType.toPlatformMessageType(),
575
+ modifiers = " override${ if (method.serverStreaming) " " else " suspend" } " ,
576
+ args = " request: ${inputType.toPlatformMessageType().wrapInFlowIf(method.clientStreaming) } " ,
577
+ returnType = outputType.toPlatformMessageType().wrapInFlowIf(method.serverStreaming) ,
572
578
) {
573
- code(" return impl.${method.name} (request.toKotlin()).toPlatform()" )
579
+ val toKotlin = if (method.clientStreaming) {
580
+ " map { it.toKotlin() }"
581
+ } else {
582
+ " toKotlin()"
583
+ }
584
+
585
+ val toPlatform = if (method.serverStreaming) {
586
+ " map { it.toPlatform() }"
587
+ } else {
588
+ " toPlatform()"
589
+ }
590
+
591
+ code(" return impl.${method.name} (request.${toKotlin} ).${toPlatform} " )
574
592
575
593
importRootDeclarationIfNeeded(inputType.name, " toPlatform" , true )
576
594
importRootDeclarationIfNeeded(outputType.name, " toKotlin" , true )
@@ -605,36 +623,72 @@ class ModelToKotlinGenerator(
605
623
typeParameters = " R" ,
606
624
returnType = " R" ,
607
625
) {
608
- code(" val message = rpcCall.parameters[0]" )
609
- code(" @Suppress(\" UNCHECKED_CAST\" )" )
610
- scope(" return when (rpcCall.callableName)" ) {
611
- service.methods.forEach { method ->
612
- val inputType by method.inputType
613
- val outputType by method.outputType
614
- val grpcName = method.name.replaceFirstChar { it.lowercase() }
615
- val result = " stub.$grpcName ((message as ${inputType.name.safeFullName()} ).toPlatform())"
616
- code(" \" ${method.name} \" -> $result .toKotlin() as R" )
617
-
618
- importRootDeclarationIfNeeded(inputType.name, " toPlatform" , true )
619
- importRootDeclarationIfNeeded(outputType.name, " toKotlin" , true )
620
- }
626
+ val methods = service.methods.filter { ! it.serverStreaming }
621
627
622
- code(" else -> error(\" Illegal call: \$ {rpcCall.callableName}\" )" )
628
+ if (methods.isEmpty()) {
629
+ code(" error(\" Illegal call: \$ {rpcCall.callableName}\" )" )
630
+ return @function
623
631
}
632
+
633
+ generateCallsImpls(methods)
624
634
}
625
635
626
636
function(
627
637
name = " callServerStreaming" ,
628
638
modifiers = " override" ,
629
639
args = " rpcCall: kotlinx.rpc.RpcCall" ,
630
640
typeParameters = " R" ,
631
- returnType = " kotlinx.coroutines.flow. Flow<R>" ,
641
+ returnType = " Flow<R>" ,
632
642
) {
633
- code(" error(\" Flow calls are not supported\" )" )
643
+ val methods = service.methods.filter { it.serverStreaming }
644
+
645
+ if (methods.isEmpty()) {
646
+ code(" error(\" Illegal streaming call: \$ {rpcCall.callableName}\" )" )
647
+ return @function
648
+ }
649
+
650
+ generateCallsImpls(methods)
634
651
}
635
652
}
636
653
}
637
654
655
+ private fun CodeGenerator.generateCallsImpls (
656
+ methods : List <MethodDeclaration >,
657
+ ) {
658
+ code(" val message = rpcCall.parameters[0]" )
659
+ code(" @Suppress(\" UNCHECKED_CAST\" )" )
660
+ scope(" return when (rpcCall.callableName)" ) {
661
+ methods.forEach { method ->
662
+ val inputType by method.inputType
663
+ val outputType by method.outputType
664
+ val grpcName = method.name.replaceFirstChar { it.lowercase() }
665
+
666
+ val toKotlin = if (method.serverStreaming) {
667
+ " map { it.toKotlin() }"
668
+ } else {
669
+ " toKotlin()"
670
+ }
671
+
672
+ val toPlatform = if (method.clientStreaming) {
673
+ " map { it.toPlatform() }"
674
+ } else {
675
+ " toPlatform()"
676
+ }
677
+
678
+ val argumentCast = inputType.name.safeFullName().wrapInFlowIf(method.clientStreaming)
679
+ val resultCast = " R" .wrapInFlowIf(method.serverStreaming)
680
+
681
+ val result = " stub.$grpcName ((message as $argumentCast ).${toPlatform} )"
682
+ code(" \" ${method.name} \" -> $result .${toKotlin} as $resultCast " )
683
+
684
+ importRootDeclarationIfNeeded(inputType.name, " toPlatform" , true )
685
+ importRootDeclarationIfNeeded(outputType.name, " toKotlin" , true )
686
+ }
687
+
688
+ code(" else -> error(\" Illegal call: \$ {rpcCall.callableName}\" )" )
689
+ }
690
+ }
691
+
638
692
private fun MessageDeclaration.toPlatformMessageType (): String {
639
693
return " ${outerClassName.safeFullName()} .${name.fullNestedName()} "
640
694
}
0 commit comments