Skip to content

Commit 6423bb3

Browse files
committed
Logging in servlet for client-side stream cancellation
1 parent 96d9143 commit 6423bb3

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ name := "udash"
99
Global / excludeLintKeys ++= Set(ideOutputDirectory, ideSkipProject)
1010

1111
inThisBuild(Seq(
12-
version := "0.9.0-SNAPSHOT",
12+
version := "0.18.0-SNAPSHOT",
1313
organization := "io.udash",
1414
resolvers += Resolver.defaultLocal,
1515
))

rest/.jvm/src/main/scala/io/udash/rest/RestServlet.scala

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@ package rest
33

44
import com.avsystem.commons.*
55
import com.avsystem.commons.annotation.explicitGenerics
6-
import com.typesafe.scalalogging.LazyLogging
6+
import com.typesafe.scalalogging.{LazyLogging, Logger as ScalaLogger}
77
import io.udash.rest.RestServlet.*
88
import io.udash.rest.raw.*
99
import io.udash.utils.URLEncoder
1010
import monix.eval.Task
1111
import monix.execution.Scheduler
1212
import monix.reactive.{Consumer, Observable}
13+
import org.slf4j.{Logger, LoggerFactory}
1314

14-
import java.io.ByteArrayOutputStream
15+
import java.io.{ByteArrayOutputStream, EOFException}
1516
import java.util.concurrent.atomic.AtomicBoolean
1617
import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
1718
import javax.servlet.{AsyncEvent, AsyncListener}
@@ -55,12 +56,16 @@ class RestServlet(
5556
handleTimeout: FiniteDuration = DefaultHandleTimeout,
5657
maxPayloadSize: Long = DefaultMaxPayloadSize,
5758
defaultStreamingBatchSize: Int = DefaultStreamingBatchSize,
59+
customLogger: OptArg[Logger] = OptArg.Empty,
5860
)(implicit
5961
scheduler: Scheduler
6062
) extends HttpServlet with LazyLogging {
6163

6264
import RestServlet.*
6365

66+
override protected lazy val logger: ScalaLogger =
67+
ScalaLogger(customLogger.getOrElse(LoggerFactory.getLogger(getClass.getName)))
68+
6469
override def service(request: HttpServletRequest, response: HttpServletResponse): Unit = {
6570
val asyncContext = request.startAsync()
6671
val completed = new AtomicBoolean(false)
@@ -76,25 +81,28 @@ class RestServlet(
7681

7782
// readRequest must execute in Jetty thread but we want exceptions to be handled uniformly, hence the Try
7883
val udashRequest = Try(readRequest(request))
79-
val cancelable = Task.defer(handleRequest(udashRequest.get)).flatMap { rr =>
80-
Task(setResponseHeaders(response, rr.code, rr.headers)) >>
81-
writeResponseBody(response, rr)
82-
}.executeAsync.runAsync {
83-
case Right(_) =>
84-
asyncContext.complete()
85-
case Left(e: HttpErrorException) =>
86-
completeWith(writeResponse(response, e.toResponse))
87-
case Left(e) =>
88-
logger.error("Failed to handle REST request", e)
89-
completeWith(writeFailure(response, e.getMessage.opt))
90-
}
84+
val cancelable =
85+
(for {
86+
restRequest <- Task.fromTry(udashRequest)
87+
restResponse <- handleRequest(restRequest)
88+
_ <- Task(setResponseHeaders(response, restResponse.code, restResponse.headers))
89+
_ <- writeResponseBody(response, restResponse)
90+
} yield ()).executeAsync.runAsync {
91+
case Right(_) =>
92+
asyncContext.complete()
93+
case Left(e: HttpErrorException) =>
94+
completeWith(writeResponse(response, e.toResponse))
95+
case Left(e) =>
96+
logger.error("Failed to handle REST request", e)
97+
completeWith(writeFailure(response, e.getMessage.opt))
98+
}
9199

92100
asyncContext.setTimeout(handleTimeout.toMillis)
93101
asyncContext.addListener(new AsyncListener {
94102
def onComplete(event: AsyncEvent): Unit = ()
95103
def onTimeout(event: AsyncEvent): Unit = {
96104
cancelable.cancel()
97-
completeWith(writeFailure(response, Opt(s"server operation timed out after $handleTimeout")))
105+
completeWith(writeFailure(response, s"server operation timed out after $handleTimeout".opt))
98106
}
99107
def onError(event: AsyncEvent): Unit = ()
100108
def onStartAsync(event: AsyncEvent): Unit = ()
@@ -117,13 +125,13 @@ class RestServlet(
117125

118126
private def writeNonEmptyStreamedBody(
119127
response: HttpServletResponse,
120-
body: StreamedBody.NonEmpty,
128+
responseBody: StreamedBody.NonEmpty,
121129
): Task[Unit] = Task.defer {
122130
// The Content-Length header is intentionally omitted for streams.
123131
// This signals to the client that the response body size is not predetermined and will be streamed.
124132
// Clients implementing the streaming part of the REST interface contract MUST be prepared
125133
// to handle responses without Content-Length by reading data incrementally until the stream completes.
126-
body match {
134+
responseBody match {
127135
case single: StreamedBody.Single =>
128136
Task.eval(writeNonEmptyBody(response, single.body))
129137
case binary: StreamedBody.RawBinary =>
@@ -158,17 +166,23 @@ class RestServlet(
158166
})
159167
.map(_ => response.getOutputStream.write("]".getBytes(jsonList.charset)))
160168
}
161-
}.onErrorHandle { e =>
162-
// When an error occurs during streaming, we immediately close the connection rather than
163-
// attempting to send an error response. This is intentional because:
164-
// The client has likely already received and started processing partial data
165-
// for structured formats (like JSON arrays), the stream is now in an invalid state
166-
logger.error("Failure during streaming REST response", e)
167-
response.getOutputStream.close()
169+
}.onErrorHandle {
170+
case _: EOFException =>
171+
logger.warn("Request was cancelled by the client during streaming REST response")
172+
case ex =>
173+
// When an error occurs during streaming, we immediately close the connection rather than
174+
// attempting to send an error response. This is intentional because:
175+
// The client has likely already received and started processing partial data
176+
// for structured formats (like JSON arrays), the stream is now in an invalid state
177+
logger.error("Failure during streaming REST response", ex)
178+
response.getOutputStream.close()
168179
}
169180

170-
private def writeResponseBody(response: HttpServletResponse, rr: AbstractRestResponse): Task[Unit] =
171-
rr match {
181+
private def writeResponseBody(
182+
response: HttpServletResponse,
183+
restResponse: AbstractRestResponse,
184+
): Task[Unit] =
185+
restResponse match {
172186
case resp: RestResponse =>
173187
resp.body match {
174188
case HttpBody.Empty => Task.unit

0 commit comments

Comments
 (0)