Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 54 additions & 51 deletions rest/jetty/src/main/scala/io/udash/rest/jetty/JettyRestClient.scala
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package io.udash
package rest.jetty

import com.avsystem.commons._
import com.avsystem.commons.*
import com.avsystem.commons.annotation.explicitGenerics
import io.udash.rest.raw._
import io.udash.rest.raw.*
import io.udash.utils.URLEncoder
import monix.eval.Task
import org.eclipse.jetty.client.{BufferingResponseListener, BytesRequestContent, HttpClient, Result, StringRequestContent}
import monix.execution.Callback
import org.eclipse.jetty.client.*
import org.eclipse.jetty.http.{HttpCookie, HttpHeader, MimeTypes}

import java.nio.charset.Charset
import scala.concurrent.duration._
import scala.util.{Failure, Success}
import scala.concurrent.CancellationException
import scala.concurrent.duration.*

object JettyRestClient {
final val DefaultMaxResponseLength = 2 * 1024 * 1024
Expand All @@ -31,55 +32,57 @@ object JettyRestClient {
maxResponseLength: Int = DefaultMaxResponseLength,
timeout: Duration = DefaultTimeout
): RawRest.HandleRequest =
request => Task.async { callback =>
val path = baseUrl + PlainValue.encodePath(request.parameters.path)
val httpReq = client.newRequest(baseUrl).method(request.method.name)
request => Task(client.newRequest(baseUrl).method(request.method.name)).flatMap { httpReq =>
Task.async { (callback: Callback[Throwable, RestResponse]) =>
val path = baseUrl + PlainValue.encodePath(request.parameters.path)

httpReq.path(path)
request.parameters.query.entries.foreach {
case (name, PlainValue(value)) => httpReq.param(name, value)
}
request.parameters.headers.entries.foreach {
case (name, PlainValue(value)) => httpReq.headers(headers => headers.add(name, value))
}
request.parameters.cookies.entries.foreach {
case (name, PlainValue(value)) => httpReq.cookie(HttpCookie.build(
URLEncoder.encode(name, spaceAsPlus = true), URLEncoder.encode(value, spaceAsPlus = true)).build())
}

request.body match {
case HttpBody.Empty =>
case tb: HttpBody.Textual =>
httpReq.body(new StringRequestContent(tb.contentType, tb.content, Charset.forName(tb.charset)))
case bb: HttpBody.Binary =>
httpReq.body(new BytesRequestContent(bb.contentType, bb.bytes))
}
httpReq.path(path)
request.parameters.query.entries.foreach {
case (name, PlainValue(value)) => httpReq.param(name, value)
}
request.parameters.headers.entries.foreach {
case (name, PlainValue(value)) => httpReq.headers(headers => headers.add(name, value))
}
request.parameters.cookies.entries.foreach {
case (name, PlainValue(value)) => httpReq.cookie(HttpCookie.build(
URLEncoder.encode(name, spaceAsPlus = true), URLEncoder.encode(value, spaceAsPlus = true)).build())
}

timeout match {
case fd: FiniteDuration => httpReq.timeout(fd.length, fd.unit)
case _ =>
}
request.body match {
case HttpBody.Empty =>
case tb: HttpBody.Textual =>
httpReq.body(new StringRequestContent(tb.contentType, tb.content, Charset.forName(tb.charset)))
case bb: HttpBody.Binary =>
httpReq.body(new BytesRequestContent(bb.contentType, bb.bytes))
}

httpReq.send(new BufferingResponseListener(maxResponseLength) {
override def onComplete(result: Result): Unit =
if (result.isSucceeded) {
val httpResp = result.getResponse
val contentTypeOpt = httpResp.getHeaders.get(HttpHeader.CONTENT_TYPE).opt
val charsetOpt = contentTypeOpt.map(MimeTypes.getCharsetFromContentType)
val body = (contentTypeOpt, charsetOpt) match {
case (Opt(contentType), Opt(charset)) =>
HttpBody.textual(getContentAsString, MimeTypes.getContentTypeWithoutCharset(contentType), charset)
case (Opt(contentType), Opt.Empty) =>
HttpBody.binary(getContent, contentType)
case _ =>
HttpBody.Empty
}
val headers = httpResp.getHeaders.asScala.iterator.map(h => (h.getName, PlainValue(h.getValue))).toList
val response = RestResponse(httpResp.getStatus, IMapping(headers), body)
callback(Success(response))
} else {
callback(Failure(result.getFailure))
timeout match {
case fd: FiniteDuration => httpReq.timeout(fd.length, fd.unit)
case _ =>
}
})

httpReq.send(new BufferingResponseListener(maxResponseLength) {
override def onComplete(result: Result): Unit =
if (result.isSucceeded) {
val httpResp = result.getResponse
val contentTypeOpt = httpResp.getHeaders.get(HttpHeader.CONTENT_TYPE).opt
val charsetOpt = contentTypeOpt.map(MimeTypes.getCharsetFromContentType)
val body = (contentTypeOpt, charsetOpt) match {
case (Opt(contentType), Opt(charset)) =>
HttpBody.textual(getContentAsString, MimeTypes.getContentTypeWithoutCharset(contentType), charset)
case (Opt(contentType), Opt.Empty) =>
HttpBody.binary(getContent, contentType)
case _ =>
HttpBody.Empty
}
val headers = httpResp.getHeaders.asScala.iterator.map(h => (h.getName, PlainValue(h.getValue))).toList
val response = RestResponse(httpResp.getStatus, IMapping(headers), body)
callback(Success(response))
} else {
callback(Failure(result.getFailure))
}
})
}
.doOnCancel(Task(httpReq.abort(new CancellationException("Request cancelled"))))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package io.udash.rest.jetty

import com.avsystem.commons.misc.ScalaDurationExtensions.durationIntOps
import com.avsystem.commons.universalOps
import io.udash.rest.jetty.CloseStaleJettyConnectionsOnMonixTimeout.RestApiWithNeverCounter
import io.udash.rest.{DefaultRestApiCompanion, GET, RestServlet}
import monix.eval.Task
import monix.execution.atomic.Atomic
import org.eclipse.jetty.client.HttpClient
import org.eclipse.jetty.ee8.servlet.{ServletContextHandler, ServletHolder}
import org.eclipse.jetty.server.{NetworkConnector, Server}
import org.scalatest.funsuite.AsyncFunSuite

import java.net.InetSocketAddress
import scala.concurrent.Future
import scala.concurrent.duration.{FiniteDuration, IntMult}

final class CloseStaleJettyConnectionsOnMonixTimeout extends AsyncFunSuite {

test("close connection on monix task timeout") {
import monix.execution.Scheduler.Implicits.global

val MaxConnections: Int = 1 // to timeout quickly
val Connections: Int = 10 // > MaxConnections
val RequestTimeout: FiniteDuration = 1.hour // no timeout
val CallTimeout: FiniteDuration = 300.millis


val server = new Server(new InetSocketAddress("localhost", 0)) {
setHandler(
new ServletContextHandler().setup(
_.addServlet(
new ServletHolder(
RestServlet[RestApiWithNeverCounter](RestApiWithNeverCounter.Impl)
),
"/*",
)
)
)
start()
}

val httpClient = new HttpClient() {
setMaxConnectionsPerDestination(MaxConnections)
setIdleTimeout(RequestTimeout.toMillis)
start()
}

val client = JettyRestClient[RestApiWithNeverCounter](
client = httpClient,
baseUri = server.getConnectors.head |> { case connector: NetworkConnector => s"http://${connector.getHost}:${connector.getLocalPort}" },
maxResponseLength = Int.MaxValue, // to avoid unnecessary logs
timeout = RequestTimeout,
)

Task
.traverse(List.range(0, Connections))(_ => Task.fromFuture(client.neverGet).timeout(CallTimeout).failed)
.timeoutTo(Connections * CallTimeout + 500.millis, Task(fail("All connections should have been closed"))) // + 500 millis just in case
.map(_ => assert(RestApiWithNeverCounter.Impl.counter.get() == Connections)) // neverGet should be called Connections times
.guarantee(Task {
server.stop()
httpClient.stop()
})
.runToFuture
}
}

object CloseStaleJettyConnectionsOnMonixTimeout {
sealed trait RestApiWithNeverCounter {
final val counter = Atomic(0)
@GET def neverGet: Future[Unit]
}

object RestApiWithNeverCounter extends DefaultRestApiCompanion[RestApiWithNeverCounter] {
final val Impl: RestApiWithNeverCounter = new RestApiWithNeverCounter {
override def neverGet: Future[Unit] = {
counter.increment()
Future.never
}
}
}
}
10 changes: 5 additions & 5 deletions rest/src/test/scala/io/udash/rest/RestTestApi.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package io.udash
package rest

import com.avsystem.commons._
import com.avsystem.commons.*
import com.avsystem.commons.misc.{AbstractValueEnum, AbstractValueEnumCompanion, EnumCtx}
import com.avsystem.commons.rpc.AsRawReal
import com.avsystem.commons.serialization.*
import com.avsystem.commons.serialization.json.JsonStringOutput
import com.avsystem.commons.serialization.{GenCodec, HasPolyGenCodec, flatten, name, whenAbsent}
import io.udash.rest.openapi.adjusters._
import io.udash.rest.openapi.{Header => OASHeader, _}
import io.udash.rest.raw._
import io.udash.rest.openapi.adjusters.*
import io.udash.rest.openapi.{Header as OASHeader, *}
import io.udash.rest.raw.*
import monix.execution.{FutureUtils, Scheduler}

import scala.concurrent.Future
Expand Down