@@ -2,25 +2,33 @@ package misk.grpc
22
33import com.squareup.wire.ProtoAdapter
44import com.squareup.wire.WireRpc
5+ import jakarta.inject.Inject
6+ import jakarta.inject.Singleton
7+ import kotlinx.coroutines.Dispatchers
8+ import kotlinx.coroutines.channels.BufferOverflow
9+ import kotlinx.coroutines.channels.Channel
10+ import kotlinx.coroutines.channels.Channel.Factory.RENDEZVOUS
511import misk.Action
612import misk.web.DispatchMechanism
713import misk.web.FeatureBinding
814import misk.web.FeatureBinding.Claimer
915import misk.web.FeatureBinding.Subject
10- import misk.web.Grpc
1116import misk.web.PathPattern
17+ import misk.web.WebConfig
1218import misk.web.actions.findAnnotationWithOverrides
1319import misk.web.mediatype.MediaTypes
1420import java.lang.reflect.Type
15- import jakarta.inject.Inject
16- import jakarta.inject.Singleton
21+ import kotlin.coroutines.CoroutineContext
1722
1823internal class GrpcFeatureBinding (
1924 private val requestAdapter : ProtoAdapter <Any >,
2025 private val responseAdapter : ProtoAdapter <Any >,
2126 private val streamingRequest : Boolean ,
22- private val streamingResponse : Boolean
27+ private val streamingResponse : Boolean ,
28+ private val isSuspend : Boolean ,
29+ private val grpcMessageSourceChannelContext : CoroutineContext
2330) : FeatureBinding {
31+
2432 override fun beforeCall (subject : Subject ) {
2533 val requestBody = subject.takeRequestBody()
2634 val messageSource = GrpcMessageSource (
@@ -29,7 +37,19 @@ internal class GrpcFeatureBinding(
2937 )
3038
3139 if (streamingRequest) {
32- subject.setParameter(0 , messageSource)
40+ val param: Any = if (isSuspend) {
41+ GrpcMessageSourceChannel (
42+ channel = Channel (
43+ capacity = RENDEZVOUS ,
44+ onBufferOverflow = BufferOverflow .SUSPEND ,
45+ ),
46+ source = messageSource,
47+ coroutineContext = grpcMessageSourceChannelContext,
48+ )
49+ } else {
50+ messageSource
51+ }
52+ subject.setParameter(0 , param)
3353 } else {
3454 val request = messageSource.read()!!
3555 subject.setParameter(0 , request)
@@ -38,9 +58,20 @@ internal class GrpcFeatureBinding(
3858 if (streamingResponse) {
3959 val responseBody = subject.takeResponseBody()
4060 val messageSink = GrpcMessageSink (responseBody, responseAdapter, grpcEncoding = " identity" )
61+ val param: Any = if (isSuspend) {
62+ GrpcMessageSinkChannel (
63+ channel = Channel (
64+ capacity = RENDEZVOUS ,
65+ onBufferOverflow = BufferOverflow .SUSPEND ,
66+ ),
67+ sink = messageSink,
68+ )
69+ } else {
70+ messageSink
71+ }
4172
4273 // It's a streaming response, give the call a SendChannel to write to.
43- subject.setParameter(1 , messageSink )
74+ subject.setParameter(1 , param )
4475 setResponseHeaders(subject)
4576 }
4677 }
@@ -72,7 +103,19 @@ internal class GrpcFeatureBinding(
72103 }
73104
74105 @Singleton
75- class Factory @Inject internal constructor() : FeatureBinding.Factory {
106+ class Factory @Inject internal constructor(
107+ webConfig : WebConfig
108+ ) : FeatureBinding.Factory {
109+
110+ // This dispatcher is sized to the jetty thread pool size to make sure that
111+ // no requests that are currently scheduled on a jetty thread are ever blocked
112+ // from reading a streaming request
113+ private val grpcMessageSourceChannelDispatcher =
114+ Dispatchers .IO .limitedParallelism(
115+ parallelism = webConfig.jetty_max_thread_pool_size,
116+ name = " GrpcMessageSourceChannel.bridgeFromSource"
117+ )
118+
76119 override fun create (
77120 action : Action ,
78121 pathPattern : PathPattern ,
@@ -96,30 +139,35 @@ internal class GrpcFeatureBinding(
96139 val responseAdapter = if (action.parameters.size == 2 ) {
97140 claimer.claimParameter(1 )
98141 val responseType: Type = action.parameters[1 ].type.streamElementType()
99- ? : error(" @Grpc function's second parameter should be a MessageSource : $action " )
142+ ? : error(" @Grpc function's second parameter should be a MessageSink(blocking) or SendChannel(suspending) : $action " )
100143 @Suppress(" UNCHECKED_CAST" ) // Assume it's a proto type.
101144 ProtoAdapter .get(responseType as Class <Any >)
102145 } else {
103146 claimer.claimReturnValue()
104147 @Suppress(" UNCHECKED_CAST" ) // Assume it's a proto type.
105148 ProtoAdapter .get(wireAnnotation.responseAdapter) as ProtoAdapter <Any >
106149 }
150+ val isSuspending = action.function.isSuspend
107151
108152 return if (streamingRequestType != null ) {
109153 @Suppress(" UNCHECKED_CAST" ) // Assume it's a proto type.
110154 GrpcFeatureBinding (
111155 requestAdapter = ProtoAdapter .get(streamingRequestType as Class <Any >),
112156 responseAdapter = responseAdapter,
113157 streamingRequest = true ,
114- streamingResponse = streamingResponse
158+ streamingResponse = streamingResponse,
159+ isSuspend = isSuspending,
160+ grpcMessageSourceChannelContext = grpcMessageSourceChannelDispatcher,
115161 )
116162 } else {
117163 @Suppress(" UNCHECKED_CAST" ) // Assume it's a proto type.
118164 GrpcFeatureBinding (
119165 requestAdapter = ProtoAdapter .get(wireAnnotation.requestAdapter) as ProtoAdapter <Any >,
120166 responseAdapter = responseAdapter,
121167 streamingRequest = false ,
122- streamingResponse = streamingResponse
168+ streamingResponse = streamingResponse,
169+ isSuspend = isSuspending,
170+ grpcMessageSourceChannelContext = grpcMessageSourceChannelDispatcher,
123171 )
124172 }
125173 }
0 commit comments