diff --git a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineClientResponse.kt b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineClientResponse.kt index 5815c87..ffd15ab 100644 --- a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineClientResponse.kt +++ b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineClientResponse.kt @@ -16,14 +16,50 @@ package org.springframework.web.coroutine.function.client +import kotlinx.coroutines.experimental.reactive.awaitFirstOrNull +import org.springframework.http.HttpStatus +import org.springframework.http.ResponseCookie +import org.springframework.http.ResponseEntity +import org.springframework.util.MultiValueMap import org.springframework.web.reactive.function.client.ClientResponse +import org.springframework.web.reactive.function.client.ClientResponse.Headers +import org.springframework.web.reactive.function.client.ExchangeStrategies interface CoroutineClientResponse { -// fun body(extractor: BodyExtractor): T + + fun statusCode(): HttpStatus + + fun headers(): Headers + + fun cookies(): MultiValueMap + + fun strategies(): ExchangeStrategies + + suspend fun body(elementClass: Class): T? + + suspend fun toEntity(bodyType: Class): ResponseEntity + } internal class DefaultCoroutineClientResponse( private val clientResponse: ClientResponse -): CoroutineClientResponse { +) : CoroutineClientResponse { + + override fun statusCode(): HttpStatus = clientResponse.statusCode() + + override fun headers(): Headers = clientResponse.headers() + + override fun cookies(): MultiValueMap = clientResponse.cookies() + + override fun strategies(): ExchangeStrategies = clientResponse.strategies() + + override suspend fun body(elementClass: Class): T? = + clientResponse.bodyToMono(elementClass).awaitFirstOrNull() + + override suspend fun toEntity(bodyType: Class): ResponseEntity = + ResponseEntity(clientResponse.bodyToMono(bodyType).awaitFirstOrNull(), headers().asHttpHeaders(), statusCode()) +} + +suspend inline fun CoroutineClientResponse.body(): T? = body(T::class.java) -} \ No newline at end of file +suspend inline fun CoroutineClientResponse.toEntity(): ResponseEntity = toEntity(T::class.java) \ No newline at end of file diff --git a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineWebClient.kt b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineWebClient.kt index 6ab4bc0..1a2a9e4 100644 --- a/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineWebClient.kt +++ b/spring-webflux-kotlin-coroutine/src/main/kotlin/org/springframework/web/coroutine/function/client/CoroutineWebClient.kt @@ -18,10 +18,16 @@ package org.springframework.web.coroutine.function.client import org.springframework.http.HttpHeaders import org.springframework.http.HttpMethod +import org.springframework.http.HttpStatus import org.springframework.http.MediaType import org.springframework.kotlin.experimental.coroutine.web.awaitFirstOrNull import org.springframework.util.MultiValueMap +import org.springframework.web.coroutine.function.client.CoroutineWebClient.CoroutineResponseSpec +import org.springframework.web.coroutine.function.client.CoroutineWebClient.RequestBodySpec +import org.springframework.web.coroutine.function.client.CoroutineWebClient.RequestHeadersSpec +import org.springframework.web.reactive.function.client.ClientResponse import org.springframework.web.reactive.function.client.WebClient +import reactor.core.publisher.Mono import java.net.URI import java.nio.charset.Charset import java.time.ZonedDateTime @@ -46,6 +52,15 @@ interface CoroutineWebClient { interface RequestBodyUriSpec: RequestBodySpec, RequestHeadersUriSpec interface RequestBodySpec: RequestHeadersSpec { + + fun contentLength(contentLength: Long): RequestBodySpec + + fun contentType(contentType: MediaType): RequestBodySpec + + fun body(body: T, elementClass: Class): RequestHeadersSpec<*> + + fun syncBody(body: Any): RequestHeadersSpec<*> + } interface RequestHeadersUriSpec>: UriSpec, RequestHeadersSpec { @@ -80,13 +95,18 @@ interface CoroutineWebClient { fun attributes(attributesConsumer: (Map) -> Unit): T - suspend fun retrieve(): CoroutineResponseSpec + fun retrieve(): CoroutineResponseSpec suspend fun exchange(): CoroutineClientResponse? } interface CoroutineResponseSpec { suspend fun body(clazz: Class): T? + + fun onStatus( + statusPredicate: (HttpStatus) -> Boolean, + exceptionFunction: (ClientResponse) -> Mono + ): CoroutineResponseSpec } companion object { @@ -97,6 +117,8 @@ interface CoroutineWebClient { suspend inline fun CoroutineWebClient.CoroutineResponseSpec.body(): T? = body(T::class.java) +inline fun CoroutineWebClient.RequestBodySpec.body(body: T): RequestHeadersSpec<*> = body(body, T::class.java) + open class DefaultCoroutineWebClient( private val client: WebClient ) : CoroutineWebClient { @@ -121,18 +143,26 @@ open class DefaultCoroutineWebClient( } private fun WebClient.ResponseSpec.asCoroutineResponseSpec(): CoroutineWebClient.CoroutineResponseSpec = - DefaultCoroutineResponseSpec(this) + DefaultCoroutineResponseSpec(this) open class DefaultCoroutineResponseSpec( private val spec: WebClient.ResponseSpec ): CoroutineWebClient.CoroutineResponseSpec { - override suspend fun body(clazz: Class): T? = - spec.bodyToMono(clazz).awaitFirstOrNull() + + override suspend fun body(clazz: Class): T? = spec.bodyToMono(clazz).awaitFirstOrNull() + + override fun onStatus( + statusPredicate: (HttpStatus) -> Boolean, + exceptionFunction: (ClientResponse) -> Mono + ): CoroutineResponseSpec = apply { + spec.onStatus(statusPredicate, exceptionFunction) + } } open class DefaultRequestBodyUriSpec( private val spec: WebClient.RequestBodyUriSpec ): CoroutineWebClient.RequestBodyUriSpec { + override fun uri(uri: String, vararg uriVariables: Any): CoroutineWebClient.RequestBodySpec = apply { spec.uri(uri, *uriVariables) } @@ -189,6 +219,22 @@ open class DefaultRequestBodyUriSpec( DefaultCoroutineClientResponse(it) } - override suspend fun retrieve(): CoroutineWebClient.CoroutineResponseSpec = + override fun retrieve(): CoroutineWebClient.CoroutineResponseSpec = spec.retrieve().asCoroutineResponseSpec() + + override fun contentLength(contentLength: Long): RequestBodySpec = apply { + spec.contentLength(contentLength) + } + + override fun contentType(contentType: MediaType): RequestBodySpec = apply { + spec.contentType(contentType) + } + + override fun body(body: T, elementClass: Class): RequestHeadersSpec<*> = apply { + spec.body(Mono.justOrEmpty(body), elementClass) + } + + override fun syncBody(body: Any): RequestHeadersSpec<*> = apply { + spec.syncBody(body) + } } \ No newline at end of file diff --git a/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/CoroutineWebClientSpec.groovy b/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/CoroutineWebClientSpec.groovy index 2126652..e0c57c0 100644 --- a/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/CoroutineWebClientSpec.groovy +++ b/spring-webflux-kotlin-coroutine/src/test/groovy/org/springframework/kotlin/experimental/coroutine/CoroutineWebClientSpec.groovy @@ -1,12 +1,16 @@ package org.springframework.kotlin.experimental.coroutine + import org.springframework.boot.autoconfigure.EnableAutoConfiguration import org.springframework.boot.test.context.SpringBootTest import org.springframework.boot.web.server.LocalServerPort +import org.springframework.http.HttpStatus import org.springframework.http.MediaType import org.springframework.kotlin.experimental.coroutine.web.WebConfiguration +import org.springframework.web.coroutine.function.client.CoroutineClientResponse import org.springframework.web.coroutine.function.client.CoroutineWebClient import org.springframework.web.coroutine.function.client.CoroutineWebClientUtils +import org.springframework.web.reactive.function.client.WebClientResponseException import spock.lang.Specification import static org.springframework.kotlin.experimental.coroutine.TestUtilsKt.runBlocking @@ -14,25 +18,83 @@ import static org.springframework.kotlin.experimental.coroutine.TestUtilsKt.runB @SpringBootTest(classes = [IntSpecConfiguration, WebConfiguration], webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) @EnableAutoConfiguration class CoroutineWebClientSpec extends Specification { + @LocalServerPort private int port - def "should be able to access endpoint with CoroutineWebClient with custom header"() { + def "should be able to access endpoint with CoroutineWebClient with custom header and body using retrieve()"() { + when: + final String url = "http://localhost:$port" + final CoroutineWebClient.CoroutineResponseSpec spec = CoroutineWebClientUtils.createCoroutineWebClient(url) + .post() + .uri("/postWithHeaderAndBodyTest") + .header("X-Coroutine-Test", "123456") + .contentType(MediaType.APPLICATION_JSON) + .body([text: "free-text", num: 1, flag: true], Map) + .retrieve() + + final Map response = runBlocking { cont -> + spec.body(Map, cont) + } + + then: + response["header"] == "123456" + response["body"] == [text: "free-text", num: 1, flag: true] + } + + def "should be able to access endpoint with CoroutineWebClient with custom header and body using exchange()"() { when: final String url = "http://localhost:$port" - final CoroutineWebClient.CoroutineResponseSpec spec = runBlocking { cont -> + final CoroutineClientResponse resp = runBlocking { cont -> CoroutineWebClientUtils.createCoroutineWebClient(url) .post() - .uri("/postWithHeaderTest") + .uri("/postWithHeaderAndBodyTest") .header("X-Coroutine-Test", "123456") - .retrieve(cont) + .contentType(MediaType.APPLICATION_JSON) + .body([text: "free-text", num: 1, flag: true], Map) + .exchange(cont) + } + + final Map responseBody = runBlocking { cont -> + resp.body(Map, cont) + } + + then: + resp.statusCode() == HttpStatus.CREATED + responseBody["header"] == "123456" + responseBody["body"] == [text: "free-text", num: 1, flag: true] + } + + def "should fail to access endpoint with CoroutineWebClient with invalid content-type using exchange()"() { + when: + final String url = "http://localhost:$port" + final CoroutineClientResponse resp = runBlocking { cont -> + CoroutineWebClientUtils.createCoroutineWebClient(url) + .post() + .uri("/postWithHeaderAndBodyTest") + .contentType(MediaType.IMAGE_JPEG) + .exchange(cont) } - final String response = runBlocking { cont -> - spec.body(String, cont) + then: + resp.statusCode() == HttpStatus.UNSUPPORTED_MEDIA_TYPE + } + + def "should fail to access endpoint with CoroutineWebClient with invalid content-type using retrieve()"() { + when: + final String url = "http://localhost:$port" + runBlocking { cont -> + CoroutineWebClientUtils.createCoroutineWebClient(url) + .post() + .uri("/postWithHeaderAndBodyTest") + .contentType(MediaType.IMAGE_GIF) + .retrieve() + .body(Map, cont) } then: - response == "123456" + def e = thrown WebClientResponseException + e.statusCode == HttpStatus.UNSUPPORTED_MEDIA_TYPE + e.statusText == "Unsupported Media Type" } } \ No newline at end of file diff --git a/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/CoroutineController.kt b/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/CoroutineController.kt index 7f36182..2d82c7a 100644 --- a/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/CoroutineController.kt +++ b/spring-webflux-kotlin-coroutine/src/test/kotlin/org/springframework/kotlin/experimental/coroutine/web/CoroutineController.kt @@ -20,13 +20,16 @@ import kotlinx.coroutines.experimental.CommonPool import kotlinx.coroutines.experimental.channels.ReceiveChannel import kotlinx.coroutines.experimental.channels.produce import kotlinx.coroutines.experimental.delay +import org.springframework.http.HttpStatus.CREATED import org.springframework.http.MediaType import org.springframework.kotlin.experimental.coroutine.annotation.Coroutine import org.springframework.kotlin.experimental.coroutine.context.COMMON_POOL import org.springframework.web.bind.annotation.GetMapping import org.springframework.web.bind.annotation.PathVariable import org.springframework.web.bind.annotation.PostMapping +import org.springframework.web.bind.annotation.RequestBody import org.springframework.web.bind.annotation.RequestHeader +import org.springframework.web.bind.annotation.ResponseStatus import org.springframework.web.bind.annotation.RestController @RestController @@ -72,19 +75,19 @@ open class CoroutineController { @GetMapping("/sseChannelMultiply/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE]) open fun sseChannelMultiply(@PathVariable("a") a: Int, @PathVariable("b") b: Int) = //: ReceiveChannel - channelMultiply(a, b) + channelMultiply(a, b) @GetMapping("/sseDelayedChannelMultiply/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE]) open suspend fun sseDelayedChannelMultiply(@PathVariable("a") a: Int, @PathVariable("b") b: Int): ReceiveChannel = - delayedChannelMultiply(a, b) + delayedChannelMultiply(a, b) @GetMapping("/sseChannelMultiplyList/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE]) open suspend fun sseChannelMultiplyList(@PathVariable("a") a: Int, @PathVariable("b") b: Int): List = - channelMultiplyList(a, b) + channelMultiplyList(a, b) @GetMapping("/sseSuspendChannelMultiply/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE]) open suspend fun sseSuspendChannelMultiply(@PathVariable("a") a: Int, @PathVariable("b") b: Int) = //: ReceiveChannel - suspendChannelMultiply(a, b) + suspendChannelMultiply(a, b) @GetMapping("/test", produces = [MediaType.TEXT_EVENT_STREAM_VALUE]) open fun test() = produce(CommonPool) { @@ -105,6 +108,11 @@ open class CoroutineController { @PostMapping("/postTest") open suspend fun postTest(): Int = 123456 - @PostMapping("/postWithHeaderTest") - open suspend fun postWithHeaderTest(@RequestHeader("X-Coroutine-Test") headerValue: String): String = headerValue + @ResponseStatus(CREATED) + @PostMapping("/postWithHeaderAndBodyTest", consumes = [MediaType.APPLICATION_JSON_VALUE]) + open suspend fun postWithHeaderAndBodyTest( + @RequestHeader("X-Coroutine-Test") headerValue: String, + @RequestBody body: Map + ): Map = mapOf("header" to headerValue, "body" to body) + } \ No newline at end of file