@@ -15,14 +15,19 @@ import io.r2dbc.postgresql.PostgresqlConnectionFactory
1515import io.r2dbc.postgresql.client.SSLMode
1616import io.r2dbc.spi.Connection
1717import io.r2dbc.spi.ConnectionFactory
18+ import kotlinx.coroutines.async
19+ import kotlinx.coroutines.awaitAll
20+ import kotlinx.coroutines.coroutineScope
1821import kotlinx.coroutines.flow.Flow
1922import kotlinx.coroutines.flow.flow
2023import kotlinx.coroutines.reactive.awaitFirst
2124import kotlinx.coroutines.reactive.awaitFirstOrNull
25+ import kotlinx.coroutines.reactor.awaitSingle
2226import kotlinx.html.*
2327import reactor.core.publisher.Flux
2428import reactor.core.publisher.Mono
2529import java.time.Duration
30+ import java.util.concurrent.ThreadLocalRandom
2631import kotlin.random.Random
2732
2833const val HELLO_WORLD = " Hello, World!"
@@ -36,7 +41,6 @@ fun Application.main() {
3641 val dbConnFactory = configurePostgresR2DBC(config)
3742
3843 val helloWorldContent = TextContent (" Hello, World!" , ContentType .Text .Plain )
39- val random = Random .Default
4044
4145 install(DefaultHeaders )
4246
@@ -50,23 +54,23 @@ fun Application.main() {
5054 }
5155
5256 get(" /db" ) {
53- val request = getWorld(dbConnFactory, random )
57+ val request = getWorld(dbConnFactory)
5458 val result = request.awaitFirstOrNull()
5559
5660 call.respondJson(result)
5761 }
5862
59- fun selectWorlds (queries : Int , random : Random ): Flow <World > = flow {
63+ fun selectWorlds (queries : Int ): Flow <World > = flow {
6064 repeat(queries) {
61- emit(getWorld(dbConnFactory, random ).awaitFirst())
65+ emit(getWorld(dbConnFactory).awaitFirst())
6266 }
6367 }
6468
6569 get(" /queries" ) {
6670 val queries = call.queries()
6771
6872 val result = buildList {
69- selectWorlds(queries, random ).collect {
73+ selectWorlds(queries).collect {
7074 add(it)
7175 }
7276 }
@@ -113,48 +117,56 @@ fun Application.main() {
113117 get(" /updates" ) {
114118 val queries = call.queries()
115119
116- val worlds = selectWorlds(queries, random)
117-
118- val worldsUpdated = buildList {
119- worlds.collect { world ->
120- world.randomNumber = random.nextInt(DB_ROWS ) + 1
121- add(world)
122-
123- Mono .usingWhen(dbConnFactory.create(), { connection ->
124- Mono .from(
125- connection.createStatement(UPDATE_QUERY )
126- .bind(0 , world.randomNumber)
127- .bind(1 , world.id)
128- .execute()
129- ).flatMap { Mono .from(it.rowsUpdated) }
130- }, Connection ::close).awaitFirstOrNull()
131- }
132- }
120+ val worlds = fetchWorldsConcurrently(dbConnFactory, queries)
121+ val updatedWorlds = worlds.map {
122+ it.copy(randomNumber = ThreadLocalRandom .current().nextInt(1 , DB_ROWS + 1 ))
123+ }.sortedBy { it.id }
133124
134- call.respondJson(worldsUpdated)
125+ Mono .usingWhen(dbConnFactory.create(), { connection ->
126+ connection.beginTransaction()
127+ val statement = connection.createStatement(UPDATE_QUERY )
128+ updatedWorlds.forEach { world ->
129+ statement.bind(0 , world.randomNumber).bind(1 , world.id).add()
130+ }
131+ Mono .from(statement.execute())
132+ .flatMap { Mono .from(it.rowsUpdated) }
133+ .then(Mono .from(connection.commitTransaction()))
134+ },
135+ Connection ::close,
136+ { connection, _ -> connection.rollbackTransaction() },
137+ { connection -> connection.rollbackTransaction() }
138+ ).awaitSingle()
139+
140+ call.respondJson(updatedWorlds)
135141 }
136142 }
137143}
138144
139145private fun getWorld (
140- dbConnFactory : ConnectionFactory , random : Random
146+ dbConnFactory : ConnectionFactory , random : ThreadLocalRandom = ThreadLocalRandom .current()
141147): Mono <World > = Mono .usingWhen(dbConnFactory.create(), { connection ->
142148 Mono .from(connection.createStatement(WORLD_QUERY )
143149 .bind(" $1" , random.nextInt(DB_ROWS ) + 1 )
144150 .execute())
145- .flatMap { r ->
146- Mono .from(r.map { row, _ ->
147- val id = row.get(0 , Int ::class .java)
148- val randomNumber = row.get(1 , Int ::class .java)
149- if (id != null && randomNumber != null ) {
150- World (id, randomNumber)
151- } else {
152- throw IllegalStateException (" Database returned null values for required fields" )
153- }
151+ .flatMap { result ->
152+ Mono .from(result.map { row, _ ->
153+ World (
154+ row.get(0 , Int ::class .java)
155+ ? : error(" id is null" ),
156+ row.get(1 , Int ::class .java)
157+ ? : error(" randomNumber is null" )
158+ )
154159 })
155160 }
156161}, Connection ::close)
157162
163+ suspend fun fetchWorldsConcurrently (factory : ConnectionFactory , count : Int ): List <World > =
164+ coroutineScope {
165+ (0 until count).map {
166+ async { getWorld(factory, ThreadLocalRandom .current()).awaitSingle() }
167+ }.awaitAll()
168+ }
169+
158170private fun configurePostgresR2DBC (config : ApplicationConfig ): ConnectionFactory {
159171 val cfo = PostgresqlConnectionConfiguration .builder()
160172 .host(config.property(" db.host" ).getString())
0 commit comments