Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,50 @@

package org.springframework.web.coroutine.function.client

import kotlinx.coroutines.experimental.reactive.awaitFirstOrNull
import org.springframework.http.HttpStatus
import org.springframework.http.ResponseCookie
import org.springframework.http.ResponseEntity
import org.springframework.util.MultiValueMap
import org.springframework.web.reactive.function.client.ClientResponse
import org.springframework.web.reactive.function.client.ClientResponse.Headers
import org.springframework.web.reactive.function.client.ExchangeStrategies

interface CoroutineClientResponse {
// fun <T> body(extractor: BodyExtractor<T, in ClientHttpResponse>): T

fun statusCode(): HttpStatus

fun headers(): Headers

fun cookies(): MultiValueMap<String, ResponseCookie>

fun strategies(): ExchangeStrategies

suspend fun <T> body(elementClass: Class<T>): T?

suspend fun <T> toEntity(bodyType: Class<T>): ResponseEntity<T?>

}

internal class DefaultCoroutineClientResponse(
private val clientResponse: ClientResponse
): CoroutineClientResponse {
) : CoroutineClientResponse {

override fun statusCode(): HttpStatus = clientResponse.statusCode()

override fun headers(): Headers = clientResponse.headers()

override fun cookies(): MultiValueMap<String, ResponseCookie> = clientResponse.cookies()

override fun strategies(): ExchangeStrategies = clientResponse.strategies()

override suspend fun <T> body(elementClass: Class<T>): T? =
clientResponse.bodyToMono(elementClass).awaitFirstOrNull()

override suspend fun <T> toEntity(bodyType: Class<T>): ResponseEntity<T?> =
ResponseEntity(clientResponse.bodyToMono(bodyType).awaitFirstOrNull(), headers().asHttpHeaders(), statusCode())
}

suspend inline fun <reified T : Any> CoroutineClientResponse.body(): T? = body(T::class.java)

}
suspend inline fun <reified T : Any> CoroutineClientResponse.toEntity(): ResponseEntity<T?> = toEntity(T::class.java)
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@ package org.springframework.web.coroutine.function.client

import org.springframework.http.HttpHeaders
import org.springframework.http.HttpMethod
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.kotlin.experimental.coroutine.web.awaitFirstOrNull
import org.springframework.util.MultiValueMap
import org.springframework.web.coroutine.function.client.CoroutineWebClient.CoroutineResponseSpec
import org.springframework.web.coroutine.function.client.CoroutineWebClient.RequestBodySpec
import org.springframework.web.coroutine.function.client.CoroutineWebClient.RequestHeadersSpec
import org.springframework.web.reactive.function.client.ClientResponse
import org.springframework.web.reactive.function.client.WebClient
import reactor.core.publisher.Mono
import java.net.URI
import java.nio.charset.Charset
import java.time.ZonedDateTime
Expand All @@ -46,6 +52,15 @@ interface CoroutineWebClient {
interface RequestBodyUriSpec: RequestBodySpec, RequestHeadersUriSpec<RequestBodySpec>

interface RequestBodySpec: RequestHeadersSpec<RequestBodySpec> {

fun contentLength(contentLength: Long): RequestBodySpec

fun contentType(contentType: MediaType): RequestBodySpec

fun <T> body(body: T, elementClass: Class<T>): RequestHeadersSpec<*>

fun syncBody(body: Any): RequestHeadersSpec<*>

}

interface RequestHeadersUriSpec<T: RequestHeadersSpec<T>>: UriSpec<T>, RequestHeadersSpec<T> {
Expand Down Expand Up @@ -80,13 +95,18 @@ interface CoroutineWebClient {

fun attributes(attributesConsumer: (Map<String, Any>) -> Unit): T

suspend fun retrieve(): CoroutineResponseSpec
fun retrieve(): CoroutineResponseSpec

suspend fun exchange(): CoroutineClientResponse?
}

interface CoroutineResponseSpec {
suspend fun <T> body(clazz: Class<T>): T?

fun onStatus(
statusPredicate: (HttpStatus) -> Boolean,
exceptionFunction: (ClientResponse) -> Mono<out Throwable>
): CoroutineResponseSpec
}

companion object {
Expand All @@ -97,6 +117,8 @@ interface CoroutineWebClient {

suspend inline fun <reified T : Any> CoroutineWebClient.CoroutineResponseSpec.body(): T? = body(T::class.java)

inline fun <reified T : Any> CoroutineWebClient.RequestBodySpec.body(body: T): RequestHeadersSpec<*> = body(body, T::class.java)

open class DefaultCoroutineWebClient(
private val client: WebClient
) : CoroutineWebClient {
Expand All @@ -121,18 +143,26 @@ open class DefaultCoroutineWebClient(
}

private fun WebClient.ResponseSpec.asCoroutineResponseSpec(): CoroutineWebClient.CoroutineResponseSpec =
DefaultCoroutineResponseSpec(this)
DefaultCoroutineResponseSpec(this)

open class DefaultCoroutineResponseSpec(
private val spec: WebClient.ResponseSpec
): CoroutineWebClient.CoroutineResponseSpec {
override suspend fun <T> body(clazz: Class<T>): T? =
spec.bodyToMono(clazz).awaitFirstOrNull()

override suspend fun <T> body(clazz: Class<T>): T? = spec.bodyToMono(clazz).awaitFirstOrNull()

override fun onStatus(
statusPredicate: (HttpStatus) -> Boolean,
exceptionFunction: (ClientResponse) -> Mono<out Throwable>
): CoroutineResponseSpec = apply {
spec.onStatus(statusPredicate, exceptionFunction)
}
}

open class DefaultRequestBodyUriSpec(
private val spec: WebClient.RequestBodyUriSpec
): CoroutineWebClient.RequestBodyUriSpec {

override fun uri(uri: String, vararg uriVariables: Any): CoroutineWebClient.RequestBodySpec = apply {
spec.uri(uri, *uriVariables)
}
Expand Down Expand Up @@ -189,6 +219,22 @@ open class DefaultRequestBodyUriSpec(
DefaultCoroutineClientResponse(it)
}

override suspend fun retrieve(): CoroutineWebClient.CoroutineResponseSpec =
override fun retrieve(): CoroutineWebClient.CoroutineResponseSpec =
spec.retrieve().asCoroutineResponseSpec()

override fun contentLength(contentLength: Long): RequestBodySpec = apply {
spec.contentLength(contentLength)
}

override fun contentType(contentType: MediaType): RequestBodySpec = apply {
spec.contentType(contentType)
}

override fun <T> body(body: T, elementClass: Class<T>): RequestHeadersSpec<*> = apply {
spec.body(Mono.justOrEmpty(body), elementClass)
}

override fun syncBody(body: Any): RequestHeadersSpec<*> = apply {
spec.syncBody(body)
}
}
Original file line number Diff line number Diff line change
@@ -1,38 +1,100 @@
package org.springframework.kotlin.experimental.coroutine


import org.springframework.boot.autoconfigure.EnableAutoConfiguration
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.boot.web.server.LocalServerPort
import org.springframework.http.HttpStatus
import org.springframework.http.MediaType
import org.springframework.kotlin.experimental.coroutine.web.WebConfiguration
import org.springframework.web.coroutine.function.client.CoroutineClientResponse
import org.springframework.web.coroutine.function.client.CoroutineWebClient
import org.springframework.web.coroutine.function.client.CoroutineWebClientUtils
import org.springframework.web.reactive.function.client.WebClientResponseException
import spock.lang.Specification

import static org.springframework.kotlin.experimental.coroutine.TestUtilsKt.runBlocking

@SpringBootTest(classes = [IntSpecConfiguration, WebConfiguration], webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@EnableAutoConfiguration
class CoroutineWebClientSpec extends Specification {

@LocalServerPort
private int port

def "should be able to access endpoint with CoroutineWebClient with custom header"() {
def "should be able to access endpoint with CoroutineWebClient with custom header and body using retrieve()"() {
when:
final String url = "http://localhost:$port"
final CoroutineWebClient.CoroutineResponseSpec spec = CoroutineWebClientUtils.createCoroutineWebClient(url)
.post()
.uri("/postWithHeaderAndBodyTest")
.header("X-Coroutine-Test", "123456")
.contentType(MediaType.APPLICATION_JSON)
.body([text: "free-text", num: 1, flag: true], Map)
.retrieve()

final Map<String, Object> response = runBlocking { cont ->
spec.body(Map, cont)
}

then:
response["header"] == "123456"
response["body"] == [text: "free-text", num: 1, flag: true]
}

def "should be able to access endpoint with CoroutineWebClient with custom header and body using exchange()"() {
when:
final String url = "http://localhost:$port"
final CoroutineWebClient.CoroutineResponseSpec spec = runBlocking { cont ->
final CoroutineClientResponse resp = runBlocking { cont ->
CoroutineWebClientUtils.createCoroutineWebClient(url)
.post()
.uri("/postWithHeaderTest")
.uri("/postWithHeaderAndBodyTest")
.header("X-Coroutine-Test", "123456")
.retrieve(cont)
.contentType(MediaType.APPLICATION_JSON)
.body([text: "free-text", num: 1, flag: true], Map)
.exchange(cont)
}

final Map<String, Object> responseBody = runBlocking { cont ->
resp.body(Map, cont)
}

then:
resp.statusCode() == HttpStatus.CREATED
responseBody["header"] == "123456"
responseBody["body"] == [text: "free-text", num: 1, flag: true]
}

def "should fail to access endpoint with CoroutineWebClient with invalid content-type using exchange()"() {
when:
final String url = "http://localhost:$port"
final CoroutineClientResponse resp = runBlocking { cont ->
CoroutineWebClientUtils.createCoroutineWebClient(url)
.post()
.uri("/postWithHeaderAndBodyTest")
.contentType(MediaType.IMAGE_JPEG)
.exchange(cont)
}

final String response = runBlocking { cont ->
spec.body(String, cont)
then:
resp.statusCode() == HttpStatus.UNSUPPORTED_MEDIA_TYPE
}

def "should fail to access endpoint with CoroutineWebClient with invalid content-type using retrieve()"() {
when:
final String url = "http://localhost:$port"
runBlocking { cont ->
CoroutineWebClientUtils.createCoroutineWebClient(url)
.post()
.uri("/postWithHeaderAndBodyTest")
.contentType(MediaType.IMAGE_GIF)
.retrieve()
.body(Map, cont)
}

then:
response == "123456"
def e = thrown WebClientResponseException
e.statusCode == HttpStatus.UNSUPPORTED_MEDIA_TYPE
e.statusText == "Unsupported Media Type"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ import kotlinx.coroutines.experimental.CommonPool
import kotlinx.coroutines.experimental.channels.ReceiveChannel
import kotlinx.coroutines.experimental.channels.produce
import kotlinx.coroutines.experimental.delay
import org.springframework.http.HttpStatus.CREATED
import org.springframework.http.MediaType
import org.springframework.kotlin.experimental.coroutine.annotation.Coroutine
import org.springframework.kotlin.experimental.coroutine.context.COMMON_POOL
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.PathVariable
import org.springframework.web.bind.annotation.PostMapping
import org.springframework.web.bind.annotation.RequestBody
import org.springframework.web.bind.annotation.RequestHeader
import org.springframework.web.bind.annotation.ResponseStatus
import org.springframework.web.bind.annotation.RestController

@RestController
Expand Down Expand Up @@ -72,19 +75,19 @@ open class CoroutineController {

@GetMapping("/sseChannelMultiply/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
open fun sseChannelMultiply(@PathVariable("a") a: Int, @PathVariable("b") b: Int) = //: ReceiveChannel<Int>
channelMultiply(a, b)
channelMultiply(a, b)

@GetMapping("/sseDelayedChannelMultiply/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
open suspend fun sseDelayedChannelMultiply(@PathVariable("a") a: Int, @PathVariable("b") b: Int): ReceiveChannel<Int> =
delayedChannelMultiply(a, b)
delayedChannelMultiply(a, b)

@GetMapping("/sseChannelMultiplyList/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
open suspend fun sseChannelMultiplyList(@PathVariable("a") a: Int, @PathVariable("b") b: Int): List<Int> =
channelMultiplyList(a, b)
channelMultiplyList(a, b)

@GetMapping("/sseSuspendChannelMultiply/{a}/{b}", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
open suspend fun sseSuspendChannelMultiply(@PathVariable("a") a: Int, @PathVariable("b") b: Int) = //: ReceiveChannel<Int>
suspendChannelMultiply(a, b)
suspendChannelMultiply(a, b)

@GetMapping("/test", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
open fun test() = produce(CommonPool) {
Expand All @@ -105,6 +108,11 @@ open class CoroutineController {
@PostMapping("/postTest")
open suspend fun postTest(): Int = 123456

@PostMapping("/postWithHeaderTest")
open suspend fun postWithHeaderTest(@RequestHeader("X-Coroutine-Test") headerValue: String): String = headerValue
@ResponseStatus(CREATED)
@PostMapping("/postWithHeaderAndBodyTest", consumes = [MediaType.APPLICATION_JSON_VALUE])
open suspend fun postWithHeaderAndBodyTest(
@RequestHeader("X-Coroutine-Test") headerValue: String,
@RequestBody body: Map<String, Any>
): Map<String, Any> = mapOf("header" to headerValue, "body" to body)

}