@@ -9,6 +9,7 @@ package kotlinx.rpc.protobuf
99import kotlinx.rpc.protobuf.CodeGenerator.DeclarationType
1010import kotlinx.rpc.protobuf.model.*
1111import org.slf4j.Logger
12+ import kotlin.getValue
1213
1314private const val RPC_INTERNAL_PACKAGE_SUFFIX = " _rpc_internal"
1415
@@ -50,6 +51,7 @@ class ModelToKotlinGenerator(
5051 generatePublicDeclaredEntities(this @generatePublicKotlinFile)
5152
5253 import(" kotlinx.rpc.internal.utils.*" )
54+ import(" kotlinx.coroutines.flow.*" )
5355
5456 additionalPublicImports.forEach {
5557 import(it)
@@ -76,6 +78,7 @@ class ModelToKotlinGenerator(
7678 generateInternalDeclaredEntities(this @generateInternalKotlinFile)
7779
7880 import(" kotlinx.rpc.internal.utils.*" )
81+ import(" kotlinx.coroutines.flow.*" )
7982
8083 additionalInternalImports.forEach {
8184 import(it)
@@ -510,19 +513,22 @@ class ModelToKotlinGenerator(
510513 code(" @kotlinx.rpc.grpc.annotations.Grpc" )
511514 clazz(service.name.simpleName, declarationType = DeclarationType .Interface ) {
512515 service.methods.forEach { method ->
513- // no streaming for now
514516 val inputType by method.inputType
515517 val outputType by method.outputType
516518 function(
517519 name = method.name,
518- modifiers = " suspend" ,
519- args = " message: ${inputType.name.safeFullName()} " ,
520- returnType = outputType.name.safeFullName(),
520+ modifiers = if (method.serverStreaming) " " else " suspend" ,
521+ args = " message: ${inputType.name.safeFullName().wrapInFlowIf(method.clientStreaming) } " ,
522+ returnType = outputType.name.safeFullName().wrapInFlowIf(method.serverStreaming) ,
521523 )
522524 }
523525 }
524526 }
525527
528+ private fun String.wrapInFlowIf (condition : Boolean ): String {
529+ return if (condition) " Flow<$this >" else this
530+ }
531+
526532 private fun CodeGenerator.generateInternalService (service : ServiceDeclaration ) {
527533 code(" @Suppress(\" unused\" , \" all\" )" )
528534 clazz(
@@ -566,11 +572,23 @@ class ModelToKotlinGenerator(
566572
567573 function(
568574 name = grpcName,
569- modifiers = " override suspend" ,
570- args = " request: ${inputType.toPlatformMessageType()} " ,
571- returnType = outputType.toPlatformMessageType(),
575+ modifiers = " override${ if (method.serverStreaming) " " else " suspend" } " ,
576+ args = " request: ${inputType.toPlatformMessageType().wrapInFlowIf(method.clientStreaming) } " ,
577+ returnType = outputType.toPlatformMessageType().wrapInFlowIf(method.serverStreaming) ,
572578 ) {
573- code(" return impl.${method.name} (request.toKotlin()).toPlatform()" )
579+ val toKotlin = if (method.clientStreaming) {
580+ " map { it.toKotlin() }"
581+ } else {
582+ " toKotlin()"
583+ }
584+
585+ val toPlatform = if (method.serverStreaming) {
586+ " map { it.toPlatform() }"
587+ } else {
588+ " toPlatform()"
589+ }
590+
591+ code(" return impl.${method.name} (request.${toKotlin} ).${toPlatform} " )
574592
575593 importRootDeclarationIfNeeded(inputType.name, " toPlatform" , true )
576594 importRootDeclarationIfNeeded(outputType.name, " toKotlin" , true )
@@ -605,36 +623,72 @@ class ModelToKotlinGenerator(
605623 typeParameters = " R" ,
606624 returnType = " R" ,
607625 ) {
608- code(" val message = rpcCall.parameters[0]" )
609- code(" @Suppress(\" UNCHECKED_CAST\" )" )
610- scope(" return when (rpcCall.callableName)" ) {
611- service.methods.forEach { method ->
612- val inputType by method.inputType
613- val outputType by method.outputType
614- val grpcName = method.name.replaceFirstChar { it.lowercase() }
615- val result = " stub.$grpcName ((message as ${inputType.name.safeFullName()} ).toPlatform())"
616- code(" \" ${method.name} \" -> $result .toKotlin() as R" )
617-
618- importRootDeclarationIfNeeded(inputType.name, " toPlatform" , true )
619- importRootDeclarationIfNeeded(outputType.name, " toKotlin" , true )
620- }
626+ val methods = service.methods.filter { ! it.serverStreaming }
621627
622- code(" else -> error(\" Illegal call: \$ {rpcCall.callableName}\" )" )
628+ if (methods.isEmpty()) {
629+ code(" error(\" Illegal call: \$ {rpcCall.callableName}\" )" )
630+ return @function
623631 }
632+
633+ generateCallsImpls(methods)
624634 }
625635
626636 function(
627637 name = " callServerStreaming" ,
628638 modifiers = " override" ,
629639 args = " rpcCall: kotlinx.rpc.RpcCall" ,
630640 typeParameters = " R" ,
631- returnType = " kotlinx.coroutines.flow. Flow<R>" ,
641+ returnType = " Flow<R>" ,
632642 ) {
633- code(" error(\" Flow calls are not supported\" )" )
643+ val methods = service.methods.filter { it.serverStreaming }
644+
645+ if (methods.isEmpty()) {
646+ code(" error(\" Illegal streaming call: \$ {rpcCall.callableName}\" )" )
647+ return @function
648+ }
649+
650+ generateCallsImpls(methods)
634651 }
635652 }
636653 }
637654
655+ private fun CodeGenerator.generateCallsImpls (
656+ methods : List <MethodDeclaration >,
657+ ) {
658+ code(" val message = rpcCall.parameters[0]" )
659+ code(" @Suppress(\" UNCHECKED_CAST\" )" )
660+ scope(" return when (rpcCall.callableName)" ) {
661+ methods.forEach { method ->
662+ val inputType by method.inputType
663+ val outputType by method.outputType
664+ val grpcName = method.name.replaceFirstChar { it.lowercase() }
665+
666+ val toKotlin = if (method.serverStreaming) {
667+ " map { it.toKotlin() }"
668+ } else {
669+ " toKotlin()"
670+ }
671+
672+ val toPlatform = if (method.clientStreaming) {
673+ " map { it.toPlatform() }"
674+ } else {
675+ " toPlatform()"
676+ }
677+
678+ val argumentCast = inputType.name.safeFullName().wrapInFlowIf(method.clientStreaming)
679+ val resultCast = " R" .wrapInFlowIf(method.serverStreaming)
680+
681+ val result = " stub.$grpcName ((message as $argumentCast ).${toPlatform} )"
682+ code(" \" ${method.name} \" -> $result .${toKotlin} as $resultCast " )
683+
684+ importRootDeclarationIfNeeded(inputType.name, " toPlatform" , true )
685+ importRootDeclarationIfNeeded(outputType.name, " toKotlin" , true )
686+ }
687+
688+ code(" else -> error(\" Illegal call: \$ {rpcCall.callableName}\" )" )
689+ }
690+ }
691+
638692 private fun MessageDeclaration.toPlatformMessageType (): String {
639693 return " ${outerClassName.safeFullName()} .${name.fullNestedName()} "
640694 }
0 commit comments