Skip to content

Commit ec17b53

Browse files
jclynesvc-squareup-copybara
authored andcommitted
Adds support for wire.kotlin rpcCallStyle =
"suspending" GitOrigin-RevId: 05fa1f17678cdc51f2f2e57acffd638038fc81d4
1 parent 75884b1 commit ec17b53

File tree

16 files changed

+492
-23
lines changed

16 files changed

+492
-23
lines changed

docs/actions.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,41 @@ class GreeterActionModule : KAbstractModule() {
187187
}
188188
```
189189

190+
Misk also supports `rpcCallStyle = "suspending"` for suspending gRPC actions.This is the preferred way to generate server
191+
actions if you intend on using coroutines to implement the business logic of your action. See [coroutines](coroutines.md) for more information.
192+
193+
```kotlin
194+
wire {
195+
sourcePath {
196+
srcDir("src/main/proto")
197+
}
198+
199+
kotlin {
200+
include("squareup.cash.hello.GreeterService")
201+
rpcCallStyle = "suspending"
202+
rpcRole = "server"
203+
singleMethodServices = true
204+
}
205+
206+
java {
207+
}
208+
}
209+
```
210+
211+
The above will generate a similar action class, but with a suspending action function
212+
213+
```kotlin
214+
@Singleton
215+
class HelloGrpcAction @Inject internal constructor()
216+
: GreeterServiceHelloBlockingServer, WebAction {
217+
218+
@Unauthorized
219+
override suspend fun Hello(request: HelloRequest): HelloResponse {
220+
return HelloResponse("message")
221+
}
222+
}
223+
```
224+
190225
Creating a gRPC action automatically creates a JSON endpoint with all of the same annotations in the
191226
path defined by the `...BlockingServer`, typically `/<package>.<service name>/<rpc name>`.
192227

docs/coroutines.md

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
Misk Coroutines Rules
2+
3+
Coroutines are cooperative concurrency. If you don't cooperate (ie. suspend regularly), things will not be as efficient
4+
and may potentially deadlock.
5+
6+
Misk allows for per request threading for which coroutines can be used to model concurrent operations. If the coroutines suspend,
7+
then multiple coroutines can be run in parallel on the same thread. If coroutines need to block, they will run serially unless
8+
a `Dispatcher` is provided that allows for running on multiple threads. Both models are supported.
9+
10+
Mixing thread based concurrency primitives, to synchronize coroutines, can result in deadlocks if threads are blocked waiting
11+
Example
12+
```class DangerTest {
13+
@Test
14+
fun threads() {
15+
val latch = CountDownLatch(3)
16+
thread { latch.countDown() }
17+
thread { latch.countDown() }
18+
thread { latch.countDown() }
19+
latch.await()
20+
}
21+
22+
@Test
23+
fun coroutineDeadlock(){
24+
runBlocking {
25+
val latch = CountDownLatch(3)
26+
async { latch.countDown() }
27+
async { latch.countDown() }
28+
async { latch.countDown() }
29+
latch.await()
30+
}
31+
}
32+
}```
33+
34+
35+
When an action is declared with the `suspend` modifier, it'll be called with a `Dispatcher` that has a single backing
36+
thread (`runBlocking`). This thread is part of the Jetty Thread Pool and allocated to this specific request, therefore
37+
it is safe to make blocking calls on. This will also take care of request scoped features that are thread local,
38+
such as ActionScoped values, MDC, tracing, etc.
39+
40+
When that function returns, the request and response body streams will be flushed and closed immediately.
41+
The framework will release these resources for you.
42+
43+
Follow structured concurrency best practices, including:
44+
- All coroutines should be in child scopes for the incoming request scope.
45+
- Don't use `GlobalScope` or create a new `CoroutineScope()` object.
46+
47+
48+
49+

gradle/libs.versions.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ kotlinStdLibJdk8 = { module = "org.jetbrains.kotlin:kotlin-stdlib-jdk8", version
160160
kotlinTest = { module = "org.jetbrains.kotlin:kotlin-test", version.ref = "kotlin" }
161161
kotlinXHtml = { module = "org.jetbrains.kotlinx:kotlinx-html-jvm", version = "0.11.0" }
162162
kotlinxCoroutinesCore = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version = "1.9.0" }
163+
kotlinxCoroutinesDebug = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-debug", version = "1.9.0" }
164+
kotlinxCoroutinesJdk8 = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-jdk8", version = "1.9.0" }
165+
kotlinxCoroutinesTest = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version = "1.9.0" }
163166
kubernetesClient = { module = "io.kubernetes:client-java", version = "18.0.1" }
164167
kubernetesClientApi = { module = "io.kubernetes:client-java-api", version = "18.0.1" }
165168
launchDarkly = { module = "com.launchdarkly:launchdarkly-java-server-sdk", version = "6.3.0" }

misk-action-scopes/src/main/kotlin/misk/scope/ActionScope.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class ActionScope @Inject internal constructor(
3131
/**
3232
* Wraps a [kotlinx.coroutines.runBlocking] to propagate the current action scope.
3333
*/
34+
@Deprecated(
35+
message = "don't use runBlocking explicitly",
36+
replaceWith = ReplaceWith("use suspending invocation")
37+
)
3438
fun <T> runBlocking(block: suspend CoroutineScope.() -> T): T {
3539
return if (inScope()) {
3640
kotlinx.coroutines.runBlocking(asContextElement(), block)
@@ -44,6 +48,10 @@ class ActionScope @Inject internal constructor(
4448
/**
4549
* Wraps a [kotlinx.coroutines.runBlocking] to propagate the current action scope.
4650
*/
51+
@Deprecated(
52+
message = "don't use runBlocking explicitly",
53+
replaceWith = ReplaceWith("use suspending invocation")
54+
)
4755
fun <T> runBlocking(context: CoroutineContext, block: suspend CoroutineScope.() -> T): T {
4856
return if (inScope()) {
4957
kotlinx.coroutines.runBlocking(context + asContextElement(), block)

misk/api/misk.api

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1683,10 +1683,6 @@ public final class misk/web/actions/StatusAction$ServerStatus {
16831683
public fun toString ()Ljava/lang/String;
16841684
}
16851685

1686-
public final class misk/web/actions/WebActionsKt {
1687-
public static final fun asChain (Lmisk/web/actions/WebAction;Lkotlin/reflect/KFunction;Ljava/util/List;Ljava/util/List;Lmisk/web/HttpCall;)Lmisk/Chain;
1688-
}
1689-
16901686
public abstract interface class misk/web/concurrencylimits/ConcurrencyLimiterFactory {
16911687
public abstract fun create (Lmisk/Action;)Lcom/netflix/concurrency/limits/Limiter;
16921688
}

misk/build.gradle.kts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ dependencies {
5353
implementation(libs.loggingApi)
5454
implementation(libs.kotlinReflect)
5555
implementation(libs.kotlinStdLibJdk8)
56+
implementation(libs.kotlinxCoroutinesCore)
5657
implementation(libs.moshiAdapters)
5758
implementation(libs.okio)
5859
implementation(libs.openTracingConcurrent)
@@ -78,7 +79,9 @@ dependencies {
7879
testImplementation(libs.guavaTestLib)
7980
testImplementation(libs.junitApi)
8081
testImplementation(libs.junitParams)
82+
testImplementation(libs.kotestAssertions)
8183
testImplementation(libs.kotlinTest)
84+
testImplementation(libs.kotlinxCoroutinesTest)
8285
testImplementation(libs.logbackClassic)
8386
testImplementation(libs.okHttpMockWebServer)
8487
testImplementation(libs.openTracingMock)

misk/src/main/kotlin/misk/grpc/Grpc.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package misk.grpc
22

33
import com.squareup.wire.MessageSink
44
import com.squareup.wire.MessageSource
5+
import kotlinx.coroutines.channels.ReceiveChannel
6+
import kotlinx.coroutines.channels.SendChannel
57
import java.lang.reflect.ParameterizedType
68
import java.lang.reflect.Type
79
import java.lang.reflect.WildcardType
@@ -10,13 +12,15 @@ import kotlin.reflect.jvm.javaType
1012

1113
/**
1214
* Returns the stream element type, like `MyRequest` if this is `MessageSource<MyRequest>`.
13-
* Returns null if this is not a [MessageSource] or [MessageSink].
15+
* Returns null if this is not a [MessageSource], [MessageSink], [SendChannel] or [ReceiveChannel].
1416
*/
1517
internal fun KType.streamElementType(): Type? {
1618
// Unbox the type parameter.
1719
val parameterizedType = javaType as? ParameterizedType ?: return null
1820
if (parameterizedType.rawType != MessageSource::class.java &&
19-
parameterizedType.rawType != MessageSink::class.java
21+
parameterizedType.rawType != MessageSink::class.java &&
22+
parameterizedType.rawType != SendChannel::class.java &&
23+
parameterizedType.rawType != ReceiveChannel::class.java
2024
) return null
2125
// Remove the wildcard, like 'out MessageSource' (Kotlin) or '? super MessageSource' (Java).
2226
return when (val typeArgument = parameterizedType.actualTypeArguments[0]) {

misk/src/main/kotlin/misk/grpc/GrpcFeatureBinding.kt

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,33 @@ package misk.grpc
22

33
import com.squareup.wire.ProtoAdapter
44
import com.squareup.wire.WireRpc
5+
import jakarta.inject.Inject
6+
import jakarta.inject.Singleton
7+
import kotlinx.coroutines.Dispatchers
8+
import kotlinx.coroutines.channels.BufferOverflow
9+
import kotlinx.coroutines.channels.Channel
10+
import kotlinx.coroutines.channels.Channel.Factory.RENDEZVOUS
511
import misk.Action
612
import misk.web.DispatchMechanism
713
import misk.web.FeatureBinding
814
import misk.web.FeatureBinding.Claimer
915
import misk.web.FeatureBinding.Subject
10-
import misk.web.Grpc
1116
import misk.web.PathPattern
17+
import misk.web.WebConfig
1218
import misk.web.actions.findAnnotationWithOverrides
1319
import misk.web.mediatype.MediaTypes
1420
import java.lang.reflect.Type
15-
import jakarta.inject.Inject
16-
import jakarta.inject.Singleton
21+
import kotlin.coroutines.CoroutineContext
1722

1823
internal class GrpcFeatureBinding(
1924
private val requestAdapter: ProtoAdapter<Any>,
2025
private val responseAdapter: ProtoAdapter<Any>,
2126
private val streamingRequest: Boolean,
22-
private val streamingResponse: Boolean
27+
private val streamingResponse: Boolean,
28+
private val isSuspend: Boolean,
29+
private val grpcMessageSourceChannelContext: CoroutineContext
2330
) : FeatureBinding {
31+
2432
override fun beforeCall(subject: Subject) {
2533
val requestBody = subject.takeRequestBody()
2634
val messageSource = GrpcMessageSource(
@@ -29,7 +37,19 @@ internal class GrpcFeatureBinding(
2937
)
3038

3139
if (streamingRequest) {
32-
subject.setParameter(0, messageSource)
40+
val param: Any = if (isSuspend) {
41+
GrpcMessageSourceChannel(
42+
channel = Channel(
43+
capacity = RENDEZVOUS,
44+
onBufferOverflow = BufferOverflow.SUSPEND,
45+
),
46+
source = messageSource,
47+
coroutineContext = grpcMessageSourceChannelContext,
48+
)
49+
} else {
50+
messageSource
51+
}
52+
subject.setParameter(0, param)
3353
} else {
3454
val request = messageSource.read()!!
3555
subject.setParameter(0, request)
@@ -38,9 +58,20 @@ internal class GrpcFeatureBinding(
3858
if (streamingResponse) {
3959
val responseBody = subject.takeResponseBody()
4060
val messageSink = GrpcMessageSink(responseBody, responseAdapter, grpcEncoding = "identity")
61+
val param: Any = if (isSuspend) {
62+
GrpcMessageSinkChannel(
63+
channel = Channel(
64+
capacity = RENDEZVOUS,
65+
onBufferOverflow = BufferOverflow.SUSPEND,
66+
),
67+
sink = messageSink,
68+
)
69+
} else {
70+
messageSink
71+
}
4172

4273
// It's a streaming response, give the call a SendChannel to write to.
43-
subject.setParameter(1, messageSink)
74+
subject.setParameter(1, param)
4475
setResponseHeaders(subject)
4576
}
4677
}
@@ -72,7 +103,19 @@ internal class GrpcFeatureBinding(
72103
}
73104

74105
@Singleton
75-
class Factory @Inject internal constructor() : FeatureBinding.Factory {
106+
class Factory @Inject internal constructor(
107+
webConfig: WebConfig
108+
) : FeatureBinding.Factory {
109+
110+
// This dispatcher is sized to the jetty thread pool size to make sure that
111+
// no requests that are currently scheduled on a jetty thread are ever blocked
112+
// from reading a streaming request
113+
private val grpcMessageSourceChannelDispatcher =
114+
Dispatchers.IO.limitedParallelism(
115+
parallelism = webConfig.jetty_max_thread_pool_size,
116+
name = "GrpcMessageSourceChannel.bridgeFromSource"
117+
)
118+
76119
override fun create(
77120
action: Action,
78121
pathPattern: PathPattern,
@@ -96,30 +139,35 @@ internal class GrpcFeatureBinding(
96139
val responseAdapter = if (action.parameters.size == 2) {
97140
claimer.claimParameter(1)
98141
val responseType: Type = action.parameters[1].type.streamElementType()
99-
?: error("@Grpc function's second parameter should be a MessageSource: $action")
142+
?: error("@Grpc function's second parameter should be a MessageSink(blocking) or SendChannel(suspending): $action")
100143
@Suppress("UNCHECKED_CAST") // Assume it's a proto type.
101144
ProtoAdapter.get(responseType as Class<Any>)
102145
} else {
103146
claimer.claimReturnValue()
104147
@Suppress("UNCHECKED_CAST") // Assume it's a proto type.
105148
ProtoAdapter.get(wireAnnotation.responseAdapter) as ProtoAdapter<Any>
106149
}
150+
val isSuspending = action.function.isSuspend
107151

108152
return if (streamingRequestType != null) {
109153
@Suppress("UNCHECKED_CAST") // Assume it's a proto type.
110154
GrpcFeatureBinding(
111155
requestAdapter = ProtoAdapter.get(streamingRequestType as Class<Any>),
112156
responseAdapter = responseAdapter,
113157
streamingRequest = true,
114-
streamingResponse = streamingResponse
158+
streamingResponse = streamingResponse,
159+
isSuspend = isSuspending,
160+
grpcMessageSourceChannelContext = grpcMessageSourceChannelDispatcher,
115161
)
116162
} else {
117163
@Suppress("UNCHECKED_CAST") // Assume it's a proto type.
118164
GrpcFeatureBinding(
119165
requestAdapter = ProtoAdapter.get(wireAnnotation.requestAdapter) as ProtoAdapter<Any>,
120166
responseAdapter = responseAdapter,
121167
streamingRequest = false,
122-
streamingResponse = streamingResponse
168+
streamingResponse = streamingResponse,
169+
isSuspend = isSuspending,
170+
grpcMessageSourceChannelContext = grpcMessageSourceChannelDispatcher,
123171
)
124172
}
125173
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package misk.grpc
2+
3+
import kotlinx.coroutines.channels.Channel
4+
import kotlinx.coroutines.channels.SendChannel
5+
import kotlinx.coroutines.channels.consumeEach
6+
import kotlinx.coroutines.runInterruptible
7+
8+
/**
9+
* Bridges a [GrpcMessageSink] to a [Channel].
10+
*
11+
* This is the primary mechanism for suspending gRPC calls to handle
12+
* Server response streaming. The [GrpcMessageSinkChannel] can be passed
13+
* to the gRPC Action function as a SendChannel to write responses to.
14+
*
15+
*/
16+
internal class GrpcMessageSinkChannel<T : Any>(
17+
private val channel: Channel<T>,
18+
private val sink: GrpcMessageSink<T>,
19+
) : SendChannel<T> by channel {
20+
21+
/**
22+
* Bridges the channel to the sink.
23+
*
24+
* This will read from the [channel] and write the messages to the [sink]
25+
* until the channel is closed for sending.
26+
*/
27+
suspend fun bridgeToSink() =
28+
channel.consumeEach {
29+
runInterruptible {
30+
sink.write(it)
31+
}
32+
}
33+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package misk.grpc
2+
3+
import kotlinx.coroutines.channels.Channel
4+
import kotlinx.coroutines.channels.ReceiveChannel
5+
import kotlinx.coroutines.runInterruptible
6+
import kotlin.coroutines.CoroutineContext
7+
8+
/**
9+
* Bridges a [GrpcMessageSource] to a [Channel].
10+
*
11+
* This is the primary mechanism for suspending gRPC calls to handle
12+
* Client request streaming. The [GrpcMessageSourceChannel] can be passed
13+
* to the gRPC Action function as a ReceiveChannel to read requests from.
14+
*
15+
*/
16+
internal class GrpcMessageSourceChannel<T : Any>(
17+
private val channel: Channel<T>,
18+
private val source: GrpcMessageSource<T>,
19+
private val coroutineContext: CoroutineContext,
20+
) : ReceiveChannel<T> by channel {
21+
22+
/**
23+
* Bridges the source to the channel.
24+
*
25+
* This will read from the [source] and send the messages to the [channel]
26+
* until the [source] is exhausted.
27+
*/
28+
suspend fun bridgeFromSource() {
29+
try {
30+
while (true) {
31+
val request = runInterruptible(coroutineContext) {
32+
source.read()
33+
} ?: break
34+
channel.send(request)
35+
}
36+
} finally {
37+
channel.close()
38+
}
39+
}
40+
41+
}

0 commit comments

Comments
 (0)