diff --git a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineRouterFunction.kt b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineRouterFunction.kt index 30131ed..b4cfa86 100644 --- a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineRouterFunction.kt +++ b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineRouterFunction.kt @@ -41,8 +41,18 @@ class CoroutineRouterFunctionDsl { infix fun RequestPredicate.and(other: String): RequestPredicate = this.and(path(other)) + infix fun RequestPredicate.or(other: String): RequestPredicate = this.or(path(other)) + + infix fun String.and(other: RequestPredicate): RequestPredicate = path(this).and(other) + + infix fun String.or(other: RequestPredicate): RequestPredicate = path(this).or(other) + + infix fun RequestPredicate.and(other: RequestPredicate): RequestPredicate = this.and(other) + infix fun RequestPredicate.or(other: RequestPredicate): RequestPredicate = this.or(other) + operator fun RequestPredicate.not(): RequestPredicate = this.negate() + fun accept(mediaType: MediaType): RequestPredicate = RequestPredicates.accept(mediaType) fun contentType(mediaType: MediaType): RequestPredicate = RequestPredicates.contentType(mediaType) @@ -53,38 +63,68 @@ class CoroutineRouterFunctionDsl { suspend fun String.nest(r: CoroutineRoutes) = path(this).nest(r) + fun GET(pattern: String, f: CoroutineHandlerFunction) { + routes += route(RequestPredicates.GET(pattern), f.asHandlerFunction()) + } + + fun GET(pattern: String): RequestPredicate = RequestPredicates.GET(pattern) + + fun POST(pattern: String, f: CoroutineHandlerFunction) { + routes += route(RequestPredicates.POST(pattern), f.asHandlerFunction()) + } + + fun POST(pattern: String): RequestPredicate = RequestPredicates.POST(pattern) + + fun PUT(pattern: String, f: CoroutineHandlerFunction) { + routes += route(RequestPredicates.PUT(pattern), f.asHandlerFunction()) + } + + fun PUT(pattern: String): RequestPredicate = RequestPredicates.PUT(pattern) + + fun DELETE(pattern: String, f: CoroutineHandlerFunction) { + routes += route(RequestPredicates.DELETE(pattern), f.asHandlerFunction()) + } + fun DELETE(pattern: String): RequestPredicate = RequestPredicates.DELETE(pattern) - fun GET(pattern: String) = RequestPredicates.GET(pattern) + fun PATCH(pattern: String, f: CoroutineHandlerFunction) { + routes += route(RequestPredicates.PATCH(pattern), f.asHandlerFunction()) + } - fun GET(pattern: String, f: CoroutineHandlerFunction) { - routes += route(RequestPredicates.GET(pattern), f.asHandlerFunction()) + fun PATCH(pattern: String): RequestPredicate = RequestPredicates.PATCH(pattern) + + fun HEAD(pattern: String, f: CoroutineHandlerFunction) { + routes += route(RequestPredicates.HEAD(pattern), f.asHandlerFunction()) } + fun HEAD(pattern: String): RequestPredicate = RequestPredicates.HEAD(pattern) + + fun OPTIONS(pattern: String, f: CoroutineHandlerFunction) { + routes += route(RequestPredicates.OPTIONS(pattern), f.asHandlerFunction()) + } + + fun OPTIONS(pattern: String): RequestPredicate = RequestPredicates.OPTIONS(pattern) + fun path(pattern: String): RequestPredicate = RequestPredicates.path(pattern) fun pathExtension(extension: String, f: CoroutineHandlerFunction) { routes += route(RequestPredicates.pathExtension(extension), f.asHandlerFunction()) } - fun POST(pattern: String, f: CoroutineHandlerFunction) { - routes += route(RequestPredicates.POST(pattern), f.asHandlerFunction()) - } - fun router(): RouterFunction { return routes.reduce(RouterFunction::and) } - operator fun RequestPredicate.invoke(f: CoroutineHandlerFunction) { + operator fun RequestPredicate.invoke(f: CoroutineHandlerFunction) { routes += route(this, f.asHandlerFunction()) } - private fun CoroutineHandlerFunction.asHandlerFunction() = HandlerFunction { + private fun CoroutineHandlerFunction.asHandlerFunction() = HandlerFunction { mono(Unconfined) { this@asHandlerFunction.invoke(org.springframework.web.coroutine.function.server.CoroutineServerRequest(it))?.extractServerResponse() } } } -operator fun RouterFunction.plus(other: RouterFunction) = - this.and(other) \ No newline at end of file +operator fun RouterFunction.plus(other: RouterFunction) = + this.and(other) \ No newline at end of file diff --git a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerRequest.kt b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerRequest.kt index b5a5f93..47be4ec 100644 --- a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerRequest.kt +++ b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerRequest.kt @@ -18,6 +18,7 @@ package org.springframework.web.coroutine.function.server import kotlinx.coroutines.experimental.channels.ReceiveChannel import kotlinx.coroutines.experimental.reactive.awaitFirstOrDefault +import kotlinx.coroutines.experimental.reactive.awaitFirstOrNull import kotlinx.coroutines.experimental.reactive.openSubscription import org.springframework.http.server.coroutine.CoroutineServerHttpRequest import org.springframework.web.server.CoroutineWebSession @@ -25,6 +26,7 @@ import org.springframework.web.coroutine.function.CoroutineBodyExtractor import org.springframework.web.reactive.function.server.ServerRequest import org.springframework.web.server.session.asCoroutineWebSession import java.net.URI +import java.security.Principal interface CoroutineServerRequest { fun body(extractor: CoroutineBodyExtractor): T @@ -45,6 +47,8 @@ interface CoroutineServerRequest { fun extractServerRequest(): ServerRequest + suspend fun principal(): Principal? + companion object { operator fun invoke(req: ServerRequest) = DefaultCoroutineServerRequest(req) } @@ -52,16 +56,16 @@ interface CoroutineServerRequest { class DefaultCoroutineServerRequest(val req: ServerRequest): CoroutineServerRequest { override fun body(extractor: CoroutineBodyExtractor): T = - req.body(extractor.asBodyExtractor()) + req.body(extractor.asBodyExtractor()) override fun body(extractor: CoroutineBodyExtractor, hints: Map): T = - req.body(extractor.asBodyExtractor(), hints) + req.body(extractor.asBodyExtractor(), hints) override suspend fun body(elementClass: Class): T? = - req.bodyToMono(elementClass).awaitFirstOrDefault(null) + req.bodyToMono(elementClass).awaitFirstOrNull() override fun bodyToReceiveChannel(elementClass: Class): ReceiveChannel = - req.bodyToFlux(elementClass).openSubscription() + req.bodyToFlux(elementClass).openSubscription() override fun headers(): ServerRequest.Headers = req.headers() @@ -72,6 +76,9 @@ class DefaultCoroutineServerRequest(val req: ServerRequest): CoroutineServerRequ override fun uri(): URI = req.uri() override fun extractServerRequest(): ServerRequest = req + + override suspend fun principal(): Principal? = + req.principal().awaitFirstOrNull() } //fun CoroutineServerRequest.language() = TODO() diff --git a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerResponse.kt b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerResponse.kt index 8d720eb..55e5d1d 100644 --- a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerResponse.kt +++ b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/server/CoroutineServerResponse.kt @@ -51,18 +51,18 @@ interface CoroutineServerResponse { fun created(location: URI): CoroutineBodyBuilder = ServerResponse.created(location).asCoroutineBodyBuilder() fun from(other: CoroutineServerResponse): CoroutineBodyBuilder = - ServerResponse.from(other.extractServerResponse()).asCoroutineBodyBuilder() + ServerResponse.from(other.extractServerResponse()).asCoroutineBodyBuilder() fun noContent(): CoroutineHeadersBuilder = - (ServerResponse.noContent() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder() + (ServerResponse.noContent() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder() fun notFound(): CoroutineHeadersBuilder = - (ServerResponse.notFound() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder() + (ServerResponse.notFound() as ServerResponse.BodyBuilder).asCoroutineBodyBuilder() fun ok(): CoroutineBodyBuilder = ServerResponse.ok().asCoroutineBodyBuilder() fun permanentRedirect(location: URI): CoroutineBodyBuilder = - ServerResponse.permanentRedirect(location).asCoroutineBodyBuilder() + ServerResponse.permanentRedirect(location).asCoroutineBodyBuilder() fun seeOther(location: URI): CoroutineBodyBuilder = ServerResponse.seeOther(location).asCoroutineBodyBuilder() @@ -71,10 +71,10 @@ interface CoroutineServerResponse { fun status(status: HttpStatus): CoroutineBodyBuilder = ServerResponse.status(status).asCoroutineBodyBuilder() fun temporaryRedirect(location: URI): CoroutineBodyBuilder = - ServerResponse.temporaryRedirect(location).asCoroutineBodyBuilder() + ServerResponse.temporaryRedirect(location).asCoroutineBodyBuilder() fun unprocessableEntity(): CoroutineBodyBuilder = - ServerResponse.unprocessableEntity().asCoroutineBodyBuilder() + ServerResponse.unprocessableEntity().asCoroutineBodyBuilder() } } @@ -83,6 +83,8 @@ interface CoroutineHeadersBuilder { fun allow(allowedMethods: Set): CoroutineHeadersBuilder + suspend fun build(): CoroutineServerResponse? + fun cacheControl(cacheControl: CacheControl): CoroutineHeadersBuilder fun cookie(cookie: ResponseCookie): CoroutineHeadersBuilder @@ -103,7 +105,6 @@ interface CoroutineHeadersBuilder { } interface CoroutineBodyBuilder: CoroutineHeadersBuilder { - suspend fun build(): CoroutineServerResponse? suspend fun body(inserter: CoroutineBodyInserter<*, in CoroutineServerHttpResponse>): CoroutineServerResponse? @@ -137,6 +138,8 @@ internal open class DefaultCoroutineHeadersBuilder.asCoroutineServerResponse(): CoroutineServerResponse? = + awaitFirstOrNull()?.let { CoroutineServerResponse(it) } } internal open class DefaultCoroutineBodyBuilder(builder: ServerResponse.BodyBuilder): DefaultCoroutineHeadersBuilder(builder), CoroutineBodyBuilder { override suspend fun build(): CoroutineServerResponse? = builder.build().asCoroutineServerResponse() override suspend fun body(inserter: CoroutineBodyInserter<*, in CoroutineServerHttpResponse>): CoroutineServerResponse? = - builder.body(inserter.asBodyInserter()).asCoroutineServerResponse() + builder.body(inserter.asBodyInserter()).asCoroutineServerResponse() override suspend fun body(value: T?, elementClass: Class): CoroutineServerResponse? = - builder.body(Mono.justOrEmpty(value), elementClass as Class).asCoroutineServerResponse() + builder.body(Mono.justOrEmpty(value), elementClass as Class).asCoroutineServerResponse() override suspend fun body(channel: ReceiveChannel, elementClass: Class): CoroutineServerResponse? = - builder.body(channel.asPublisher(Unconfined), elementClass).asCoroutineServerResponse() + builder.body(channel.asPublisher(Unconfined), elementClass).asCoroutineServerResponse() override fun contentType(contentType: MediaType): CoroutineBodyBuilder = apply { builder.contentType(contentType) } override suspend fun render(name: String, vararg modelAttributes: Any): CoroutineServerResponse? = - builder.render(name, modelAttributes).awaitFirstOrNull()?.asCoroutineServerResponse() + builder.render(name, modelAttributes).awaitFirstOrNull()?.asCoroutineServerResponse() override suspend fun render(name: String, model: Map): CoroutineServerResponse? = - builder.render(name, model).awaitFirstOrNull()?.asCoroutineServerResponse() + builder.render(name, model).awaitFirstOrNull()?.asCoroutineServerResponse() override suspend fun syncBody(body: Any): CoroutineServerResponse? = - builder.syncBody(body).asCoroutineServerResponse() - - suspend fun Mono.asCoroutineServerResponse(): CoroutineServerResponse? = - awaitFirstOrNull()?.let { CoroutineServerResponse(it) } + builder.syncBody(body).asCoroutineServerResponse() } internal open class DefaultCoroutineRenderingResponse(resp: RenderingResponse): DefaultCoroutineServerResponse(resp), CoroutineRenderingResponse { @@ -208,7 +211,7 @@ internal open class DefaultCoroutineRenderingResponse(resp: RenderingResponse): class DefaultCoroutineRenderingResponseBuilder(val builder: RenderingResponse.Builder): CoroutineRenderingResponse.Builder { override suspend fun build(): CoroutineRenderingResponse = - CoroutineRenderingResponse(builder.build().awaitFirst()) + CoroutineRenderingResponse(builder.build().awaitFirst()) override fun modelAttributes(attributes: Map): CoroutineRenderingResponse.Builder = apply { builder.modelAttributes(attributes) @@ -220,13 +223,13 @@ private fun CoroutineBodyInserter<*, in CoroutineServerHttpResponse>.asBodyInser private fun ServerResponse.BodyBuilder.asCoroutineBodyBuilder(): CoroutineBodyBuilder = CoroutineBodyBuilder(this) inline suspend fun CoroutineBodyBuilder.body(channel: ReceiveChannel): CoroutineServerResponse? = - body(channel, T::class.java) + body(channel, T::class.java) inline suspend fun CoroutineBodyBuilder.body(value: T?): CoroutineServerResponse? = - body(value, T::class.java) + body(value, T::class.java) inline fun ServerResponse.asCoroutineServerResponse(): T = - when (this) { - is RenderingResponse -> CoroutineRenderingResponse(this) - else -> CoroutineServerResponse(this) - } as T \ No newline at end of file + when (this) { + is RenderingResponse -> CoroutineRenderingResponse(this) + else -> CoroutineServerResponse(this) + } as T \ No newline at end of file diff --git a/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesIntSpec.groovy b/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesIntSpec.groovy index 40077fe..4c9599f 100644 --- a/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesIntSpec.groovy +++ b/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesIntSpec.groovy @@ -16,16 +16,19 @@ package org.springframework.kotlin.experimental.coroutine.web +import org.json.JSONObject import org.springframework.boot.autoconfigure.EnableAutoConfiguration import org.springframework.boot.test.context.SpringBootTest import org.springframework.boot.web.server.LocalServerPort import org.springframework.http.HttpStatus +import org.springframework.http.RequestEntity import org.springframework.kotlin.experimental.coroutine.IntSpecConfiguration import org.springframework.web.client.RestOperations import org.springframework.web.client.RestTemplate import spock.lang.Specification import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment +import static org.springframework.http.MediaType.APPLICATION_JSON @SpringBootTest(classes = [IntSpecConfiguration, FunctionalStyleRoutesConfiguration], webEnvironment = WebEnvironment.RANDOM_PORT) @EnableAutoConfiguration @@ -37,10 +40,45 @@ class FunctionalStyleRoutesIntSpec extends Specification { def "should handle functional style defined GET request"() { when: - def result = restTemplate.getForEntity("http://localhost:$port/test/simple/HelloWorld", String) + def result = restTemplate.getForEntity("http://localhost:$port/test-functional/simple/HelloWorld", String) then: result.statusCode == HttpStatus.OK result.body == "HelloWorld" } + + def "should handle functional style defined POST request"() { + when: + def body = new JSONObject().put("key", "ping").toString() + def request = RequestEntity.post(URI.create("http://localhost:$port/test-functional/simple")).contentType(APPLICATION_JSON).body(body) + def result = restTemplate.exchange(request, String) + + then: + result.statusCode == HttpStatus.CREATED + result.headers.getLocation().toString() == "http://localhost:$port/test-functional/simple" + result.body == new JSONObject().put("ping", "pong").toString() + } + + def "should handle functional style defined PUT request"() { + when: + def body = new JSONObject().put("key", "ping").toString() + def request = RequestEntity.put(URI.create("http://localhost:$port/test-functional/simple/1234")) + .contentType(APPLICATION_JSON) + .build() + def result = restTemplate.exchange(request, String) + + then: + result.statusCode == HttpStatus.NO_CONTENT + result.body == null + } + + def "should handle functional style defined DELETE request"() { + when: + def request = RequestEntity.delete(URI.create("http://localhost:$port/test-functional/simple/4321")).build() + def result = restTemplate.exchange(request, String) + + then: + result.statusCode == HttpStatus.NO_CONTENT + result.body == null + } } diff --git a/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesConfiguration.kt b/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesConfiguration.kt index f8553a0..356a83a 100644 --- a/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesConfiguration.kt +++ b/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/FunctionalStyleRoutesConfiguration.kt @@ -32,17 +32,32 @@ open class FunctionalStyleRoutesConfiguration { @Bean open fun apiRouter(handler: FunctionalHandler) = router { - (accept(MediaType.APPLICATION_JSON) and "/test").nest { - GET("/simple/{param}") { handler.simple(it) } + "/test-functional".nest { + (accept(MediaType.APPLICATION_JSON) and "/").nest { + GET("/simple/{param}") { handler.simpleGet(it) } + } + + ("/" and contentType(MediaType.APPLICATION_JSON)).nest { + POST("/simple") { handler.simplePost(it) } + PUT("/simple/{id}") { handler.simplePut(it) } + } + + DELETE("/simple/{id}") { handler.simpleDelete(it) } } } } @Component open class FunctionalHandler { - suspend fun simple(req: CoroutineServerRequest) = - CoroutineServerResponse - .ok() - .contentType(MediaType.TEXT_PLAIN) - .body(req.pathVariable("param")) + suspend fun simpleGet(req: CoroutineServerRequest) = + CoroutineServerResponse.ok().contentType(MediaType.TEXT_PLAIN).body(req.pathVariable("param")) + + suspend fun simplePost(req: CoroutineServerRequest) = + CoroutineServerResponse.created(req.uri()).body(mapOf(req.body>()!!["key"] to "pong")) + + suspend fun simplePut(req: CoroutineServerRequest) = + CoroutineServerResponse.noContent().build() + + suspend fun simpleDelete(req: CoroutineServerRequest) = + CoroutineServerResponse.noContent().build() }