Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,18 @@ class CoroutineRouterFunctionDsl {

infix fun RequestPredicate.and(other: String): RequestPredicate = this.and(path(other))

infix fun RequestPredicate.or(other: String): RequestPredicate = this.or(path(other))

infix fun String.and(other: RequestPredicate): RequestPredicate = path(this).and(other)

infix fun String.or(other: RequestPredicate): RequestPredicate = path(this).or(other)

infix fun RequestPredicate.and(other: RequestPredicate): RequestPredicate = this.and(other)

infix fun RequestPredicate.or(other: RequestPredicate): RequestPredicate = this.or(other)

operator fun RequestPredicate.not(): RequestPredicate = this.negate()

fun accept(mediaType: MediaType): RequestPredicate = RequestPredicates.accept(mediaType)

fun contentType(mediaType: MediaType): RequestPredicate = RequestPredicates.contentType(mediaType)
Expand All @@ -53,38 +63,68 @@ class CoroutineRouterFunctionDsl {

suspend fun String.nest(r: CoroutineRoutes) = path(this).nest(r)

fun <T : CoroutineServerResponse> GET(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.GET(pattern), f.asHandlerFunction())
}

fun GET(pattern: String): RequestPredicate = RequestPredicates.GET(pattern)

fun <T : CoroutineServerResponse> POST(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.POST(pattern), f.asHandlerFunction())
}

fun POST(pattern: String): RequestPredicate = RequestPredicates.POST(pattern)

fun <T : CoroutineServerResponse> PUT(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.PUT(pattern), f.asHandlerFunction())
}

fun PUT(pattern: String): RequestPredicate = RequestPredicates.PUT(pattern)

fun <T : CoroutineServerResponse> DELETE(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.DELETE(pattern), f.asHandlerFunction())
}

fun DELETE(pattern: String): RequestPredicate = RequestPredicates.DELETE(pattern)

fun GET(pattern: String) = RequestPredicates.GET(pattern)
fun <T : CoroutineServerResponse> PATCH(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.PATCH(pattern), f.asHandlerFunction())
}

fun <T: CoroutineServerResponse> GET(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.GET(pattern), f.asHandlerFunction())
fun PATCH(pattern: String): RequestPredicate = RequestPredicates.PATCH(pattern)

fun <T : CoroutineServerResponse> HEAD(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.HEAD(pattern), f.asHandlerFunction())
}

fun HEAD(pattern: String): RequestPredicate = RequestPredicates.HEAD(pattern)

fun <T : CoroutineServerResponse> OPTIONS(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.OPTIONS(pattern), f.asHandlerFunction())
}

fun OPTIONS(pattern: String): RequestPredicate = RequestPredicates.OPTIONS(pattern)

fun path(pattern: String): RequestPredicate = RequestPredicates.path(pattern)

fun <T: CoroutineServerResponse> pathExtension(extension: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.pathExtension(extension), f.asHandlerFunction())
}

fun <T: CoroutineServerResponse> POST(pattern: String, f: CoroutineHandlerFunction<T>) {
routes += route(RequestPredicates.POST(pattern), f.asHandlerFunction())
}

fun router(): RouterFunction<ServerResponse> {
return routes.reduce(RouterFunction<ServerResponse>::and)
}

operator fun <T: CoroutineServerResponse> RequestPredicate.invoke(f: CoroutineHandlerFunction<T>) {
operator fun <T : CoroutineServerResponse> RequestPredicate.invoke(f: CoroutineHandlerFunction<T>) {
routes += route(this, f.asHandlerFunction())
}

private fun <T: CoroutineServerResponse> CoroutineHandlerFunction<T>.asHandlerFunction() = HandlerFunction {
private fun <T : CoroutineServerResponse> CoroutineHandlerFunction<T>.asHandlerFunction() = HandlerFunction {
mono(Unconfined) {
[email protected](org.springframework.web.coroutine.function.server.CoroutineServerRequest(it))?.extractServerResponse()
}
}
}

operator fun <T: ServerResponse> RouterFunction<T>.plus(other: RouterFunction<T>) =
this.and(other)
operator fun <T : ServerResponse> RouterFunction<T>.plus(other: RouterFunction<T>) =
this.and(other)
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ package org.springframework.web.coroutine.function.server

import kotlinx.coroutines.experimental.channels.ReceiveChannel
import kotlinx.coroutines.experimental.reactive.awaitFirstOrDefault
import kotlinx.coroutines.experimental.reactive.awaitFirstOrNull
import kotlinx.coroutines.experimental.reactive.openSubscription
import org.springframework.http.server.coroutine.CoroutineServerHttpRequest
import org.springframework.web.server.CoroutineWebSession
import org.springframework.web.coroutine.function.CoroutineBodyExtractor
import org.springframework.web.reactive.function.server.ServerRequest
import org.springframework.web.server.session.asCoroutineWebSession
import java.net.URI
import java.security.Principal

interface CoroutineServerRequest {
fun <T> body(extractor: CoroutineBodyExtractor<T, CoroutineServerHttpRequest>): T
Expand All @@ -45,23 +47,25 @@ interface CoroutineServerRequest {

fun extractServerRequest(): ServerRequest

suspend fun principal(): Principal?

companion object {
operator fun invoke(req: ServerRequest) = DefaultCoroutineServerRequest(req)
}
}

class DefaultCoroutineServerRequest(val req: ServerRequest): CoroutineServerRequest {
override fun <T> body(extractor: CoroutineBodyExtractor<T, CoroutineServerHttpRequest>): T =
req.body(extractor.asBodyExtractor())
req.body(extractor.asBodyExtractor())

override fun <T> body(extractor: CoroutineBodyExtractor<T, CoroutineServerHttpRequest>, hints: Map<String, Any>): T =
req.body(extractor.asBodyExtractor(), hints)
req.body(extractor.asBodyExtractor(), hints)

override suspend fun <T> body(elementClass: Class<out T>): T? =
req.bodyToMono(elementClass).awaitFirstOrDefault(null)
req.bodyToMono(elementClass).awaitFirstOrNull()

override fun <T> bodyToReceiveChannel(elementClass: Class<out T>): ReceiveChannel<T> =
req.bodyToFlux(elementClass).openSubscription()
req.bodyToFlux(elementClass).openSubscription()

override fun headers(): ServerRequest.Headers = req.headers()

Expand All @@ -72,6 +76,9 @@ class DefaultCoroutineServerRequest(val req: ServerRequest): CoroutineServerRequ
override fun uri(): URI = req.uri()

override fun extractServerRequest(): ServerRequest = req

override suspend fun principal(): Principal? =
req.principal().awaitFirstOrNull()
}

//fun CoroutineServerRequest.language() = TODO()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,18 @@ interface CoroutineServerResponse {
fun created(location: URI): CoroutineBodyBuilder = ServerResponse.created(location).asCoroutineBodyBuilder()

fun from(other: CoroutineServerResponse): CoroutineBodyBuilder =
ServerResponse.from(other.extractServerResponse()).asCoroutineBodyBuilder()
ServerResponse.from(other.extractServerResponse()).asCoroutineBodyBuilder()

fun noContent(): CoroutineHeadersBuilder =
(ServerResponse.noContent() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder()
(ServerResponse.noContent() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder()

fun notFound(): CoroutineHeadersBuilder =
(ServerResponse.notFound() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder()
(ServerResponse.notFound() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder()

fun ok(): CoroutineBodyBuilder = ServerResponse.ok().asCoroutineBodyBuilder()

fun permanentRedirect(location: URI): CoroutineBodyBuilder =
ServerResponse.permanentRedirect(location).asCoroutineBodyBuilder()
ServerResponse.permanentRedirect(location).asCoroutineBodyBuilder()

fun seeOther(location: URI): CoroutineBodyBuilder = ServerResponse.seeOther(location).asCoroutineBodyBuilder()

Expand All @@ -71,10 +71,10 @@ interface CoroutineServerResponse {
fun status(status: HttpStatus): CoroutineBodyBuilder = ServerResponse.status(status).asCoroutineBodyBuilder()

fun temporaryRedirect(location: URI): CoroutineBodyBuilder =
ServerResponse.temporaryRedirect(location).asCoroutineBodyBuilder()
ServerResponse.temporaryRedirect(location).asCoroutineBodyBuilder()

fun unprocessableEntity(): CoroutineBodyBuilder =
ServerResponse.unprocessableEntity().asCoroutineBodyBuilder()
ServerResponse.unprocessableEntity().asCoroutineBodyBuilder()
}
}

Expand All @@ -83,6 +83,8 @@ interface CoroutineHeadersBuilder {

fun allow(allowedMethods: Set<HttpMethod>): CoroutineHeadersBuilder

suspend fun build(): CoroutineServerResponse?

fun cacheControl(cacheControl: CacheControl): CoroutineHeadersBuilder

fun cookie(cookie: ResponseCookie): CoroutineHeadersBuilder
Expand All @@ -103,7 +105,6 @@ interface CoroutineHeadersBuilder {
}

interface CoroutineBodyBuilder: CoroutineHeadersBuilder {
suspend fun build(): CoroutineServerResponse?

suspend fun body(inserter: CoroutineBodyInserter<*, in CoroutineServerHttpResponse>): CoroutineServerResponse?

Expand Down Expand Up @@ -137,6 +138,8 @@ internal open class DefaultCoroutineHeadersBuilder<T: ServerResponse.HeadersBuil
builder.allow(allowedMethods)
}

override suspend fun build(): CoroutineServerResponse? = builder.build().asCoroutineServerResponse()

override fun cacheControl(cacheControl: CacheControl): CoroutineHeadersBuilder = apply {
builder.cacheControl(cacheControl)
}
Expand Down Expand Up @@ -172,43 +175,43 @@ internal open class DefaultCoroutineHeadersBuilder<T: ServerResponse.HeadersBuil
override fun varyBy(vararg requestHeaders: String): CoroutineHeadersBuilder = apply {
builder.varyBy(*requestHeaders)
}

suspend fun Mono<ServerResponse>.asCoroutineServerResponse(): CoroutineServerResponse? =
awaitFirstOrNull()?.let { CoroutineServerResponse(it) }
}

internal open class DefaultCoroutineBodyBuilder(builder: ServerResponse.BodyBuilder): DefaultCoroutineHeadersBuilder<ServerResponse.BodyBuilder>(builder), CoroutineBodyBuilder {
override suspend fun build(): CoroutineServerResponse? = builder.build().asCoroutineServerResponse()

override suspend fun body(inserter: CoroutineBodyInserter<*, in CoroutineServerHttpResponse>): CoroutineServerResponse? =
builder.body(inserter.asBodyInserter()).asCoroutineServerResponse()
builder.body(inserter.asBodyInserter()).asCoroutineServerResponse()

override suspend fun <T> body(value: T?, elementClass: Class<T>): CoroutineServerResponse? =
builder.body(Mono.justOrEmpty(value), elementClass as Class<T?>).asCoroutineServerResponse()
builder.body(Mono.justOrEmpty(value), elementClass as Class<T?>).asCoroutineServerResponse()

override suspend fun <T> body(channel: ReceiveChannel<T>, elementClass: Class<T>): CoroutineServerResponse? =
builder.body(channel.asPublisher(Unconfined), elementClass).asCoroutineServerResponse()
builder.body(channel.asPublisher(Unconfined), elementClass).asCoroutineServerResponse()

override fun contentType(contentType: MediaType): CoroutineBodyBuilder = apply {
builder.contentType(contentType)
}

override suspend fun render(name: String, vararg modelAttributes: Any): CoroutineServerResponse? =
builder.render(name, modelAttributes).awaitFirstOrNull()?.asCoroutineServerResponse()
builder.render(name, modelAttributes).awaitFirstOrNull()?.asCoroutineServerResponse()

override suspend fun render(name: String, model: Map<String, *>): CoroutineServerResponse? =
builder.render(name, model).awaitFirstOrNull()?.asCoroutineServerResponse()
builder.render(name, model).awaitFirstOrNull()?.asCoroutineServerResponse()

override suspend fun syncBody(body: Any): CoroutineServerResponse? =
builder.syncBody(body).asCoroutineServerResponse()

suspend fun Mono<ServerResponse>.asCoroutineServerResponse(): CoroutineServerResponse? =
awaitFirstOrNull()?.let { CoroutineServerResponse(it) }
builder.syncBody(body).asCoroutineServerResponse()
}

internal open class DefaultCoroutineRenderingResponse(resp: RenderingResponse): DefaultCoroutineServerResponse(resp), CoroutineRenderingResponse {
}

class DefaultCoroutineRenderingResponseBuilder(val builder: RenderingResponse.Builder): CoroutineRenderingResponse.Builder {
override suspend fun build(): CoroutineRenderingResponse =
CoroutineRenderingResponse(builder.build().awaitFirst())
CoroutineRenderingResponse(builder.build().awaitFirst())

override fun modelAttributes(attributes: Map<String, *>): CoroutineRenderingResponse.Builder = apply {
builder.modelAttributes(attributes)
Expand All @@ -220,13 +223,13 @@ private fun CoroutineBodyInserter<*, in CoroutineServerHttpResponse>.asBodyInser
private fun ServerResponse.BodyBuilder.asCoroutineBodyBuilder(): CoroutineBodyBuilder = CoroutineBodyBuilder(this)

inline suspend fun <reified T : Any> CoroutineBodyBuilder.body(channel: ReceiveChannel<T>): CoroutineServerResponse? =
body(channel, T::class.java)
body(channel, T::class.java)

inline suspend fun <reified T: Any> CoroutineBodyBuilder.body(value: T?): CoroutineServerResponse? =
body(value, T::class.java)
body(value, T::class.java)

inline fun <T : CoroutineServerResponse> ServerResponse.asCoroutineServerResponse(): T =
when (this) {
is RenderingResponse -> CoroutineRenderingResponse(this)
else -> CoroutineServerResponse(this)
} as T
when (this) {
is RenderingResponse -> CoroutineRenderingResponse(this)
else -> CoroutineServerResponse(this)
} as T
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@

package org.springframework.kotlin.experimental.coroutine.web

import org.json.JSONObject
import org.springframework.boot.autoconfigure.EnableAutoConfiguration
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.web.server.LocalServerPort
import org.springframework.http.HttpStatus
import org.springframework.http.RequestEntity
import org.springframework.kotlin.experimental.coroutine.IntSpecConfiguration
import org.springframework.web.client.RestOperations
import org.springframework.web.client.RestTemplate
import spock.lang.Specification

import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment
import static org.springframework.http.MediaType.APPLICATION_JSON

@SpringBootTest(classes = [IntSpecConfiguration, FunctionalStyleRoutesConfiguration], webEnvironment = WebEnvironment.RANDOM_PORT)
@EnableAutoConfiguration
Expand All @@ -37,10 +40,45 @@ class FunctionalStyleRoutesIntSpec extends Specification {

def "should handle functional style defined GET request"() {
when:
def result = restTemplate.getForEntity("http://localhost:$port/test/simple/HelloWorld", String)
def result = restTemplate.getForEntity("http://localhost:$port/test-functional/simple/HelloWorld", String)

then:
result.statusCode == HttpStatus.OK
result.body == "HelloWorld"
}

def "should handle functional style defined POST request"() {
when:
def body = new JSONObject().put("key", "ping").toString()
def request = RequestEntity.post(URI.create("http://localhost:$port/test-functional/simple")).contentType(APPLICATION_JSON).body(body)
def result = restTemplate.exchange(request, String)

then:
result.statusCode == HttpStatus.CREATED
result.headers.getLocation().toString() == "http://localhost:$port/test-functional/simple"
result.body == new JSONObject().put("ping", "pong").toString()
}

def "should handle functional style defined PUT request"() {
when:
def body = new JSONObject().put("key", "ping").toString()
def request = RequestEntity.put(URI.create("http://localhost:$port/test-functional/simple/1234"))
.contentType(APPLICATION_JSON)
.build()
def result = restTemplate.exchange(request, String)

then:
result.statusCode == HttpStatus.NO_CONTENT
result.body == null
}

def "should handle functional style defined DELETE request"() {
when:
def request = RequestEntity.delete(URI.create("http://localhost:$port/test-functional/simple/4321")).build()
def result = restTemplate.exchange(request, String)

then:
result.statusCode == HttpStatus.NO_CONTENT
result.body == null
}
}
Loading