Skip to content

Commit ce1da39

Browse files
authored
feat: Run user logic on virtual threads (#2283)
Bumps Akka 2.10.3 for access to upstream feature.
1 parent 7da961f commit ce1da39

File tree

9 files changed

+121
-68
lines changed

9 files changed

+121
-68
lines changed

project/Dependencies.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ object Dependencies {
2020

2121
val ProtobufVersion = akka.grpc.gen.BuildInfo.googleProtobufVersion
2222

23-
val AkkaVersion = "2.10.2"
23+
val AkkaVersion = "2.10.3"
2424
val AkkaHttpVersion = "10.7.0" // Note: should at least the Akka HTTP version required by Akka gRPC
2525
val ScalaTestVersion = "3.2.14"
2626
// https://github.com/akka/akka/blob/main/project/Dependencies.scala#L31
@@ -163,14 +163,10 @@ object Dependencies {
163163
addSbtPlugin(sbtProtoc),
164164
addSbtPlugin(akkaGrpc))
165165

166-
lazy val excludeTheseDependencies: Seq[ExclusionRule] = Seq(
167-
// exclusion rules can be added here
168-
)
169-
170166
def akkaDependency(name: String, excludeThese: ExclusionRule*) =
171-
("com.typesafe.akka" %% name % AkkaVersion).excludeAll((excludeTheseDependencies ++ excludeThese): _*)
167+
("com.typesafe.akka" %% name % AkkaVersion).excludeAll(excludeThese: _*)
172168

173169
def akkaHttpDependency(name: String, excludeThese: ExclusionRule*) =
174-
("com.typesafe.akka" %% name % AkkaHttpVersion).excludeAll((excludeTheseDependencies ++ excludeThese): _*)
170+
("com.typesafe.akka" %% name % AkkaHttpVersion).excludeAll(excludeThese: _*)
175171

176172
}

sdk/java-sdk-protobuf/src/main/resources/reference.conf

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,12 @@ kalix {
100100
collector-endpoint = ${?COLLECTOR_ENDPOINT}
101101
}
102102
}
103+
104+
sdk-dispatcher {
105+
executor = "virtual-thread-executor"
106+
virtual-thread-executor {
107+
# if not on JDK 21
108+
fallback="fork-join-executor"
109+
}
110+
}
103111
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (C) 2021-2024 Lightbend Inc. <https://www.lightbend.com>
3+
*/
4+
5+
package kalix.javasdk.impl
6+
7+
import akka.actor.ActorSystem
8+
import akka.annotation.InternalApi
9+
import akka.stream.ActorAttributes
10+
import akka.stream.Attributes
11+
12+
import scala.concurrent.ExecutionContext
13+
14+
/**
15+
* INTERNAL API
16+
*/
17+
@InternalApi
18+
object SdkExecutionContext {
19+
val DispatcherName: String = "kalix.sdk-dispatcher"
20+
def apply(system: ActorSystem): ExecutionContext = system.dispatchers.lookup(DispatcherName)
21+
22+
val streamDispatcher: Attributes = ActorAttributes.dispatcher(DispatcherName)
23+
24+
}

sdk/java-sdk-protobuf/src/main/scala/kalix/javasdk/impl/action/ActionsImpl.scala

Lines changed: 77 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import kalix.protocol.component.{ Failure, MetadataEntry }
2525
import org.slf4j.{ Logger, LoggerFactory, MDC }
2626

2727
import java.util.Optional
28+
import scala.concurrent.ExecutionContext
2829
import scala.concurrent.Future
2930
import scala.jdk.CollectionConverters.SeqHasAsJava
3031
import scala.jdk.OptionConverters._
@@ -101,9 +102,10 @@ private[javasdk] final class ActionsImpl(_system: ActorSystem, services: Map[Str
101102

102103
import ActionsImpl._
103104
import _system.dispatcher
104-
implicit val system: ActorSystem = _system
105+
private implicit val system: ActorSystem = _system
106+
private val sdkEc: ExecutionContext = SdkExecutionContext(system)
105107
private val telemetry = Telemetry(system)
106-
lazy val telemetries: Map[String, Instrumentation] = services.values.map { s =>
108+
private lazy val telemetries: Map[String, Instrumentation] = services.values.map { s =>
107109
(s.serviceName, telemetry.traceInstrumentation(s.serviceName, ActionCategory))
108110
}.toMap
109111

@@ -176,24 +178,29 @@ private[javasdk] final class ActionsImpl(_system: ActorSystem, services: Map[Str
176178
services.get(in.serviceName) match {
177179
case Some(service) =>
178180
val span = telemetries(service.serviceName).buildSpan(service, in)
179-
span.foreach(s => MDC.put(Telemetry.TRACE_ID, s.getSpanContext.getTraceId))
181+
180182
val fut =
181-
try {
182-
val context = createContext(in, service.messageCodec, span.map(_.getSpanContext), service.serviceName)
183-
val decodedPayload = service.messageCodec.decodeMessage(
184-
in.payload.getOrElse(throw new IllegalArgumentException("No command payload")))
185-
val effect = service.factory
186-
.create(context)
187-
.handleUnary(in.name, MessageEnvelope.of(decodedPayload, context.metadata()), context)
188-
effectToResponse(service, in, effect, service.messageCodec)
189-
} catch {
190-
case NonFatal(ex) =>
191-
// command handler threw an "unexpected" error
192-
span.foreach(_.end())
193-
Future.successful(handleUnexpectedException(service, in, ex))
194-
} finally {
195-
MDC.remove(Telemetry.TRACE_ID)
196-
}
183+
// Note: invocation in future to guarantee create and invocation is running on sdk dispatcher with virtual thread support
184+
Future {
185+
try {
186+
span.foreach(s => MDC.put(Telemetry.TRACE_ID, s.getSpanContext.getTraceId))
187+
val context = createContext(in, service.messageCodec, span.map(_.getSpanContext), service.serviceName)
188+
val decodedPayload = service.messageCodec.decodeMessage(
189+
in.payload.getOrElse(throw new IllegalArgumentException("No command payload")))
190+
val effect = service.factory
191+
.create(context)
192+
.handleUnary(in.name, MessageEnvelope.of(decodedPayload, context.metadata()), context)
193+
effectToResponse(service, in, effect, service.messageCodec)
194+
} catch {
195+
case NonFatal(ex) =>
196+
// command handler threw an "unexpected" error
197+
span.foreach(_.end())
198+
Future.successful(handleUnexpectedException(service, in, ex))
199+
} finally {
200+
MDC.remove(Telemetry.TRACE_ID)
201+
}
202+
}(sdkEc).flatten
203+
197204
fut.andThen { case _ =>
198205
span.foreach(_.end())
199206
}
@@ -246,7 +253,7 @@ private[javasdk] final class ActionsImpl(_system: ActorSystem, services: Map[Str
246253
Future.successful(
247254
ActionResponse(ActionResponse.Response.Failure(Failure(0, "Unknown service: " + call.serviceName))))
248255
}
249-
}
256+
}(sdkEc)
250257

251258
/**
252259
* Handle a streamed out command. The input command will contain the service name, command name, request metadata and
@@ -258,25 +265,32 @@ private[javasdk] final class ActionsImpl(_system: ActorSystem, services: Map[Str
258265
override def handleStreamedOut(in: ActionCommand): Source[ActionResponse, NotUsed] =
259266
services.get(in.serviceName) match {
260267
case Some(service) =>
261-
try {
262-
val context = createContext(in, service.messageCodec, None, service.serviceName)
263-
val decodedPayload = service.messageCodec.decodeMessage(
264-
in.payload.getOrElse(throw new IllegalArgumentException("No command payload")))
265-
service.factory
266-
.create(context)
267-
.handleStreamedOut(in.name, MessageEnvelope.of(decodedPayload, context.metadata()), context)
268-
.asScala
269-
.mapAsync(1)(effect => effectToResponse(service, in, effect, service.messageCodec))
270-
.recover { case NonFatal(ex) =>
271-
// user stream failed with an "unexpected" error
272-
handleUnexpectedException(service, in, ex)
268+
// Note: invocation in future to guarantee create and invocation is running on sdk dispatcher with virtual thread support
269+
Source
270+
.futureSource(Future {
271+
try {
272+
val context = createContext(in, service.messageCodec, None, service.serviceName)
273+
val decodedPayload = service.messageCodec.decodeMessage(
274+
in.payload.getOrElse(throw new IllegalArgumentException("No command payload")))
275+
service.factory
276+
.create(context)
277+
.handleStreamedOut(in.name, MessageEnvelope.of(decodedPayload, context.metadata()), context)
278+
.asScala
279+
.mapAsync(1)(effect => effectToResponse(service, in, effect, service.messageCodec))
280+
.recover { case NonFatal(ex) =>
281+
// user stream failed with an "unexpected" error
282+
handleUnexpectedException(service, in, ex)
283+
}
284+
// run the stream itself on the virtual thread dispatcher in case the user blocks in stream
285+
.addAttributes(SdkExecutionContext.streamDispatcher)
286+
} catch {
287+
case NonFatal(ex) =>
288+
// command handler threw an "unexpected" error
289+
Source.single(handleUnexpectedException(service, in, ex))
273290
}
274-
.async
275-
} catch {
276-
case NonFatal(ex) =>
277-
// command handler threw an "unexpected" error
278-
Source.single(handleUnexpectedException(service, in, ex))
279-
}
291+
}(sdkEc))
292+
.mapMaterializedValue(_ => NotUsed)
293+
280294
case None =>
281295
Source.single(ActionResponse(ActionResponse.Response.Failure(Failure(0, "Unknown service: " + in.serviceName))))
282296
}
@@ -303,25 +317,32 @@ private[javasdk] final class ActionsImpl(_system: ActorSystem, services: Map[Str
303317
case (Seq(call), messages) =>
304318
services.get(call.serviceName) match {
305319
case Some(service) =>
320+
// Note: invocation in future to guarantee create and invocation is running on sdk dispatcher with virtual thread support
306321
try {
307-
val context = createContext(call, service.messageCodec, None, service.serviceName)
308-
service.factory
309-
.create(context)
310-
.handleStreamed(
311-
call.name,
312-
messages.map { message =>
313-
val metadata = MetadataImpl.of(message.metadata.map(_.entries.toVector).getOrElse(Nil))
314-
val decodedPayload = service.messageCodec.decodeMessage(
315-
message.payload.getOrElse(throw new IllegalArgumentException("No command payload")))
316-
MessageEnvelope.of(decodedPayload, metadata)
317-
}.asJava,
318-
context)
319-
.asScala
320-
.mapAsync(1)(effect => effectToResponse(service, call, effect, service.messageCodec))
321-
.recover { case NonFatal(ex) =>
322-
// user stream failed with an "unexpected" error
323-
handleUnexpectedException(service, call, ex)
324-
}
322+
Source
323+
.futureSource(Future {
324+
val context = createContext(call, service.messageCodec, None, service.serviceName)
325+
service.factory
326+
.create(context)
327+
.handleStreamed(
328+
call.name,
329+
messages.map { message =>
330+
val metadata = MetadataImpl.of(message.metadata.map(_.entries.toVector).getOrElse(Nil))
331+
val decodedPayload = service.messageCodec.decodeMessage(
332+
message.payload.getOrElse(throw new IllegalArgumentException("No command payload")))
333+
MessageEnvelope.of(decodedPayload, metadata)
334+
}.asJava,
335+
context)
336+
.asScala
337+
.mapAsync(1)(effect => effectToResponse(service, call, effect, service.messageCodec))
338+
.recover { case NonFatal(ex) =>
339+
// user stream failed with an "unexpected" error
340+
handleUnexpectedException(service, call, ex)
341+
}
342+
// run the stream itself on the virtual thread dispatcher in case the user blocks in stream
343+
.addAttributes(SdkExecutionContext.streamDispatcher)
344+
}(sdkEc))
345+
.mapMaterializedValue(_ => NotUsed)
325346
} catch {
326347
case NonFatal(ex) =>
327348
// command handler threw an "unexpected" error

sdk/java-sdk-protobuf/src/main/scala/kalix/javasdk/impl/eventsourcedentity/EventSourcedEntitiesImpl.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ final class EventSourcedEntitiesImpl(
135135
EventSourcedStreamOut(OutFailure(Failure(description = s"Unexpected failure [$correlationId]")))
136136
}
137137
}
138+
.withAttributes(SdkExecutionContext.streamDispatcher) // factory instance invoked on this stream
138139
}
139140

140141
private def runEntity(init: EventSourcedInit): Flow[EventSourcedStreamIn, EventSourcedStreamOut, NotUsed] = {
@@ -260,7 +261,7 @@ final class EventSourcedEntitiesImpl(
260261
EventSourcedStreamOut(OutFailure(Failure(description = s"Unexpected failure [$correlationId]")))
261262
}
262263
}
263-
.async
264+
.addAttributes(SdkExecutionContext.streamDispatcher)
264265
}
265266

266267
private class CommandContextImpl(

sdk/java-sdk-protobuf/src/main/scala/kalix/javasdk/impl/replicatedentity/ReplicatedEntitiesImpl.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ final class ReplicatedEntitiesImpl(system: ActorSystem, services: Map[String, Re
8787
ReplicatedEntityStreamOut(Out.Failure(Failure(description = s"Unexpected error [$correlationId]")))
8888
}
8989
}
90-
.async
90+
.addAttributes(SdkExecutionContext.streamDispatcher)
9191

9292
private def runEntity(
9393
init: ReplicatedEntityInit): Flow[ReplicatedEntityStreamIn, ReplicatedEntityStreamOut, NotUsed] = {
@@ -138,6 +138,7 @@ object ReplicatedEntitiesImpl {
138138
val router: ReplicatedEntityRouter[_ <: Object, _ <: Object] = {
139139
val context = new ReplicatedEntityCreationContext(entityId, system)
140140
try {
141+
// Note: not run on sdk execution context
141142
service.factory.create(context)
142143
} finally {
143144
context.deactivate()

sdk/java-sdk-protobuf/src/main/scala/kalix/javasdk/impl/valueentity/ValueEntitiesImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ final class ValueEntitiesImpl(
107107
ValueEntityStreamOut(OutFailure(Failure(description = s"Unexpected error [$correlationId]")))
108108
}
109109
}
110-
.async
110+
.addAttributes(SdkExecutionContext.streamDispatcher)
111111

112112
private def runEntity(init: ValueEntityInit): Flow[ValueEntityStreamIn, ValueEntityStreamOut, NotUsed] = {
113113
val service =

sdk/java-sdk-protobuf/src/main/scala/kalix/javasdk/impl/view/ViewsImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ final class ViewsImpl(system: ActorSystem, _services: Map[String, ViewService])
156156
s"Kalix protocol failure: expected ReceiveEvent message, but got ${other.getClass.getName}"
157157
Source.failed(new RuntimeException(errMsg))
158158
}
159-
.async
159+
.addAttributes(SdkExecutionContext.streamDispatcher)
160160

161161
private final class UpdateContextImpl(
162162
override val viewId: String,

sdk/java-sdk-protobuf/src/main/scala/kalix/javasdk/impl/workflow/WorkflowImpl.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ final class WorkflowImpl(system: ActorSystem, val services: Map[String, Workflow
125125
WorkflowStreamOut(OutFailure(component.Failure(description = s"Unexpected error [$correlationId]")))
126126
}
127127
}
128-
.async
128+
.withAttributes(SdkExecutionContext.streamDispatcher)
129129

130130
private def toRecoverStrategy(messageCodec: MessageCodec)(
131131
recoverStrategy: AbstractWorkflow.RecoverStrategy[_]): RecoverStrategy = {
@@ -169,6 +169,7 @@ final class WorkflowImpl(system: ActorSystem, val services: Map[String, Workflow
169169
WorkflowConfig(workflowTimeout, failoverTo, failoverRecovery, Some(stepConfig), stepConfigs)
170170
}
171171

172+
// Note: called from stream, already on sdk dispatcher
172173
private def runWorkflow(
173174
init: WorkflowEntityInit): (Flow[WorkflowStreamIn, WorkflowStreamOut, NotUsed], WorkflowStreamOut) = {
174175
val service =
@@ -345,6 +346,7 @@ final class WorkflowImpl(system: ActorSystem, val services: Map[String, Workflow
345346
// currently added to satisfy the compiler
346347
Future.successful(WorkflowStreamOut(WorkflowStreamOut.Message.Empty))
347348
}
349+
.addAttributes(SdkExecutionContext.streamDispatcher)
348350

349351
(flow, workflowConfig)
350352
}

0 commit comments

Comments
 (0)