Skip to content
22 changes: 12 additions & 10 deletions examples/src/main/scala/example/complex/ShardManagerApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ import zio._

object ShardManagerApp extends ZIOAppDefault {
def run: Task[Nothing] =
Server.run.provide(
ZLayer.succeed(ManagerConfig.default),
ZLayer.succeed(GrpcConfig.default),
ZLayer.succeed(RedisConfig.default),
redis,
StorageRedis.live, // store data in Redis
PodsHealth.local, // just ping a pod to see if it's alive
GrpcPods.live, // use gRPC protocol
ShardManager.live // Shard Manager logic
)
Server
.run()
.provide(
ZLayer.succeed(ManagerConfig.default),
ZLayer.succeed(GrpcConfig.default),
ZLayer.succeed(RedisConfig.default),
redis,
StorageRedis.live, // store data in Redis
PodsHealth.local, // just ping a pod to see if it's alive
GrpcPods.live, // use gRPC protocol
ShardManager.live // Shard Manager logic
)
}
18 changes: 10 additions & 8 deletions examples/src/main/scala/example/simple/ShardManagerApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ import zio._

object ShardManagerApp extends ZIOAppDefault {
def run: Task[Nothing] =
Server.run.provide(
ZLayer.succeed(ManagerConfig.default),
ZLayer.succeed(GrpcConfig.default),
PodsHealth.local, // just ping a pod to see if it's alive
GrpcPods.live, // use gRPC protocol
Storage.memory, // store data in memory
ShardManager.live // Shard Manager logic
)
Server
.run()
.provide(
ZLayer.succeed(ManagerConfig.default),
ZLayer.succeed(GrpcConfig.default),
PodsHealth.local, // just ping a pod to see if it's alive
GrpcPods.live, // use gRPC protocol
Storage.memory, // store data in memory
ShardManager.live // Shard Manager logic
)
}
2 changes: 1 addition & 1 deletion examples/src/test/scala/example/EndToEndSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.util.Try
object EndToEndSpec extends ZIOSpecDefault {

val shardManagerServer: ZLayer[ShardManager with ManagerConfig, Throwable, Unit] =
ZLayer(Server.run.forkDaemon *> ClockLive.sleep(3 seconds).unit)
ZLayer(Server.run().forkDaemon *> ClockLive.sleep(3 seconds).unit)

val container: ZLayer[Any, Nothing, GenericContainer] =
ZLayer.scoped {
Expand Down
64 changes: 64 additions & 0 deletions examples/src/test/scala/example/GrpcAuthExampleSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package example

import com.devsisters.shardcake._
import com.devsisters.shardcake.interfaces.{ Pods, Storage }
import io.grpc.{ Metadata, Status }
import scalapb.zio_grpc.{ ZClientInterceptor, ZTransform }
import zio.test._
import zio.{ Config => _, _ }

object GrpcAuthExampleSpec extends ZIOSpecDefault {

private val validAuthenticationKey = "validAuthenticationKey"

private val authKey = Metadata.Key.of("authentication-key", io.grpc.Metadata.ASCII_STRING_MARSHALLER)

private val config = ZLayer.succeed(Config.default.copy(simulateRemotePods = true))

private def grpcConfigLayer(clientAuthKey: String): ULayer[GrpcConfig] =
ZLayer.succeed(
GrpcConfig.default.copy(
clientInterceptors = Seq(
ZClientInterceptor.headersUpdater((_, _, md) => md.put(authKey, clientAuthKey).unit)
),
serverInterceptors = Seq(
ZTransform { requestContext =>
for {
authenticated <- requestContext.metadata.get(authKey).map(_.contains(validAuthenticationKey))
_ <- ZIO.when(!authenticated)(ZIO.fail(Status.UNAUTHENTICATED.asException))
} yield requestContext
}
)
)
)

def spec: Spec[TestEnvironment with Scope, Any] =
suite("GrpcAuthExampleSpec")(
test("auth example for gRPC") {
val podAddress = PodAddress("localhost", 54321)
ZIO.scoped {
for {
_ <- Sharding.registerScoped
podsClient <- ZIO.service[Pods]
invalidPodsClient <- ZIO
.service[Pods]
.provide(
grpcConfigLayer("invalid"),
GrpcPods.live
)
validRequest <- podsClient.ping(podAddress).exit
invalidRequest <- invalidPodsClient.ping(podAddress).exit
} yield assertTrue(validRequest.isSuccess, invalidRequest.isFailure)
}
}
).provide(
ShardManagerClient.local,
Storage.memory,
config,
grpcConfigLayer(validAuthenticationKey),
Sharding.live,
KryoSerialization.live,
GrpcPods.live,
GrpcShardingService.live
)
}
65 changes: 65 additions & 0 deletions examples/src/test/scala/example/ShardManagerAuthExampleSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package example

import com.devsisters.shardcake.{ Config, ManagerConfig, Server, ShardManager, ShardManagerClient }
import com.devsisters.shardcake.interfaces.{ Pods, PodsHealth, Storage }
import sttp.client3.SttpBackend
import sttp.client3.asynchttpclient.zio.AsyncHttpClientZioBackend
import sttp.client3.httpclient.zio.ZioWebSocketsStreams
import zio.Clock.ClockLive
import zio.http.{ Header, Middleware }
import zio.test._
import zio.{ Config => _, _ }

object ShardManagerAuthExampleSpec extends ZIOSpecDefault {

val validToken = "validBearerToken"

val shardManagerServerLayer: ZLayer[ManagerConfig, Throwable, Unit] =
ZLayer.makeSome[ManagerConfig, Unit](
ZLayer(
Server
.run(Middleware.bearerAuthZIO(secret => ZIO.succeed(secret.stringValue.equals(validToken))))
.forkDaemon *> ClockLive.sleep(3 seconds).unit
),
Storage.memory,
ShardManager.live,
Pods.noop,
PodsHealth.noop
)

def sttpBackendWithAuthTokenLayer(token: String): ZLayer[Scope, Throwable, SttpBackend[Task, ZioWebSocketsStreams]] =
ZLayer {
val authHeader = Header.Authorization.Bearer(token)
AsyncHttpClientZioBackend.scoped(customizeRequest =
builder => builder.addHeader(authHeader.headerName, authHeader.renderedValue)
)
}

def spec: Spec[TestEnvironment, Any] =
suite("ShardManagerAuthSpec")(
test("auth example for shard manager") {
ZIO.scoped {
for {
validClient <- ZIO
.service[ShardManagerClient]
.provideSome[Config & Scope](
sttpBackendWithAuthTokenLayer(validToken),
ShardManagerClient.live
)
invalidClient <- ZIO
.service[ShardManagerClient]
.provideSome[Config & Scope](
sttpBackendWithAuthTokenLayer("invalid"),
ShardManagerClient.live
)
validRequest <- validClient.getAssignments.exit
invalidRequest <- invalidClient.getAssignments.exit
} yield assertTrue(validRequest.isSuccess, invalidRequest.isFailure)
}
}
).provide(
shardManagerServerLayer,
ZLayer.succeed(Config.default),
ZLayer.succeed(ManagerConfig.default)
)
}
8 changes: 5 additions & 3 deletions manager/src/main/scala/com/devsisters/shardcake/Server.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ object Server {
/**
* Start an HTTP server that exposes the Shard Manager GraphQL API
*/
val run: RIO[ShardManager with ManagerConfig, Nothing] =
def run(
httpHandler: HandlerAspect[Any, Unit] = HandlerAspect.identity
): RIO[ShardManager with ManagerConfig, Nothing] =
for {
config <- ZIO.service[ManagerConfig]
interpreter <- (GraphQLApi.api @@ printErrors).interpreter
handlers = QuickAdapter(interpreter).handlers
routes = Routes(
Method.ANY / "health" -> Handler.ok,
Method.ANY / "api" / "graphql" -> handlers.api,
Method.ANY / "ws" / "graphql" -> handlers.webSocket
Method.ANY / "api" / "graphql" -> handlers.api @@ httpHandler,
Method.ANY / "ws" / "graphql" -> handlers.webSocket @@ httpHandler
) @@ Middleware.cors
_ <- ZIO.logInfo(s"Shard Manager server started on port ${config.apiPort}.")
nothing <- ZServer
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package com.devsisters.shardcake

import zio._
import scalapb.zio_grpc.RequestContext
import scalapb.zio_grpc.ZClientInterceptor
import scalapb.zio_grpc.ZTransform

import java.util.concurrent.Executor

Expand All @@ -11,16 +13,18 @@ import java.util.concurrent.Executor
* @param maxInboundMessageSize the maximum message size allowed to be received by the grpc client
* @param executor a custom executor to pass to grpc-java when creating gRPC clients and servers
* @param shutdownTimeout the timeout to wait for the gRPC server to shutdown before forcefully shutting it down
* @param interceptors the interceptors to be used by the gRPC client, e.g for adding tracing or logging
* @param clientInterceptors the interceptors to be used by the gRPC client, e.g for adding tracing or logging
* @param serverInterceptors the interceptors to be used by the gRPC Server, e.g for adding tracing or logging
*/
case class GrpcConfig(
maxInboundMessageSize: Int,
executor: Option[Executor],
shutdownTimeout: Duration,
interceptors: Seq[ZClientInterceptor]
clientInterceptors: Seq[ZClientInterceptor],
serverInterceptors: Seq[ZTransform[RequestContext, RequestContext]]
)

object GrpcConfig {
val default: GrpcConfig =
GrpcConfig(maxInboundMessageSize = 32 * 1024 * 1024, None, 3.seconds, Seq.empty)
GrpcConfig(maxInboundMessageSize = 32 * 1024 * 1024, None, 3.seconds, Seq.empty, Seq.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GrpcPods(
}
}

val channel = ZManagedChannel(builder, config.interceptors)
val channel = ZManagedChannel(builder, config.clientInterceptors)
// create a fiber that never ends and keeps the connection alive
for {
_ <- ZIO.logDebug(s"Opening connection to pod $pod")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,29 +84,38 @@ object GrpcShardingService {
val live: ZLayer[Config with Sharding with GrpcConfig, Throwable, Unit] =
ZLayer.scoped[Config with Sharding with GrpcConfig] {
for {
config <- ZIO.service[Config]
grpcConfig <- ZIO.service[GrpcConfig]
sharding <- ZIO.service[Sharding]
builder = grpcConfig.executor match {
case Some(executor) =>
ServerBuilder
.forPort(config.shardingPort)
.executor(executor)
case None =>
ServerBuilder.forPort(config.shardingPort)
}
services <- ServiceList.add(new GrpcShardingService(sharding, config.sendTimeout) {}).bindAll
server: Server = services
.foldLeft(builder) { case (builder0, service) => builder0.addService(service) }
.addService(ProtoReflectionService.newInstance())
.build()
_ <- ZIO.acquireRelease(ZIO.attempt(server.start()))(server =>
ZIO.attemptBlocking {
server.shutdown()
server.awaitTermination(grpcConfig.shutdownTimeout.toMillis, TimeUnit.MILLISECONDS)
server.shutdownNow()
}.ignore
)
config <- ZIO.service[Config]
grpcConfig <- ZIO.service[GrpcConfig]
sharding <- ZIO.service[Sharding]
builder = grpcConfig.executor match {
case Some(executor) =>
ServerBuilder
.forPort(config.shardingPort)
.executor(executor)
case None =>
ServerBuilder.forPort(config.shardingPort)
}
grpcShardingService = new GrpcShardingService(sharding, config.sendTimeout) {}
services <-
ServiceList
.add(
grpcConfig.serverInterceptors
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally, I tried the following to add the interceptors:

 services <- ServiceList
            .add(
              grpcConfig.serverInterceptors.foldLeft(new GrpcShardingService(sharding, config.sendTimeout) {}.asGeneric) { case (service, interceptor) =>
                service.transform(interceptor)
              }
            )
            .bindAll

but it did not like that I changed the type of the service from GShardingService[Any, StatusException] to GShardingService[RequestContext, StatusException]. So combining all of the server interceptors into one interceptor was the simplest solution.

.reduceOption(_.andThen(_))
.map(t => grpcShardingService.transform(t))
.getOrElse(grpcShardingService.asGeneric)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GrpcShardingService depends on the implicit conversion ShardingService.genericBindable to be able to add it to the list. By default the compiler does not know that it needs to use the implicit conversion in the getOrElse call so I needed to call asGeneric

)
.bindAll
server: Server = services
.foldLeft(builder) { case (builder0, service) => builder0.addService(service) }
.addService(ProtoReflectionService.newInstance())
.build()
_ <- ZIO.acquireRelease(ZIO.attempt(server.start()))(server =>
ZIO.attemptBlocking {
server.shutdown()
server.awaitTermination(grpcConfig.shutdownTimeout.toMillis, TimeUnit.MILLISECONDS)
server.shutdownNow()
}.ignore
)
} yield ()
}
}
Loading