Skip to content

Commit 746dc68

Browse files
committed
fix r2dbc update
1 parent 39a579a commit 746dc68

File tree

1 file changed

+45
-33
lines changed
  • frameworks/Kotlin/ktor/ktor-r2dbc/src/main/kotlin/org/jetbrains/ktor/benchmarks

1 file changed

+45
-33
lines changed

frameworks/Kotlin/ktor/ktor-r2dbc/src/main/kotlin/org/jetbrains/ktor/benchmarks/Hello.kt

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,19 @@ import io.r2dbc.postgresql.PostgresqlConnectionFactory
1515
import io.r2dbc.postgresql.client.SSLMode
1616
import io.r2dbc.spi.Connection
1717
import io.r2dbc.spi.ConnectionFactory
18+
import kotlinx.coroutines.async
19+
import kotlinx.coroutines.awaitAll
20+
import kotlinx.coroutines.coroutineScope
1821
import kotlinx.coroutines.flow.Flow
1922
import kotlinx.coroutines.flow.flow
2023
import kotlinx.coroutines.reactive.awaitFirst
2124
import kotlinx.coroutines.reactive.awaitFirstOrNull
25+
import kotlinx.coroutines.reactor.awaitSingle
2226
import kotlinx.html.*
2327
import reactor.core.publisher.Flux
2428
import reactor.core.publisher.Mono
2529
import java.time.Duration
30+
import java.util.concurrent.ThreadLocalRandom
2631
import kotlin.random.Random
2732

2833
const val HELLO_WORLD = "Hello, World!"
@@ -36,7 +41,6 @@ fun Application.main() {
3641
val dbConnFactory = configurePostgresR2DBC(config)
3742

3843
val helloWorldContent = TextContent("Hello, World!", ContentType.Text.Plain)
39-
val random = Random.Default
4044

4145
install(DefaultHeaders)
4246

@@ -50,23 +54,23 @@ fun Application.main() {
5054
}
5155

5256
get("/db") {
53-
val request = getWorld(dbConnFactory, random)
57+
val request = getWorld(dbConnFactory)
5458
val result = request.awaitFirstOrNull()
5559

5660
call.respondJson(result)
5761
}
5862

59-
fun selectWorlds(queries: Int, random: Random): Flow<World> = flow {
63+
fun selectWorlds(queries: Int): Flow<World> = flow {
6064
repeat(queries) {
61-
emit(getWorld(dbConnFactory, random).awaitFirst())
65+
emit(getWorld(dbConnFactory).awaitFirst())
6266
}
6367
}
6468

6569
get("/queries") {
6670
val queries = call.queries()
6771

6872
val result = buildList {
69-
selectWorlds(queries, random).collect {
73+
selectWorlds(queries).collect {
7074
add(it)
7175
}
7276
}
@@ -113,48 +117,56 @@ fun Application.main() {
113117
get("/updates") {
114118
val queries = call.queries()
115119

116-
val worlds = selectWorlds(queries, random)
117-
118-
val worldsUpdated = buildList {
119-
worlds.collect { world ->
120-
world.randomNumber = random.nextInt(DB_ROWS) + 1
121-
add(world)
122-
123-
Mono.usingWhen(dbConnFactory.create(), { connection ->
124-
Mono.from(
125-
connection.createStatement(UPDATE_QUERY)
126-
.bind(0, world.randomNumber)
127-
.bind(1, world.id)
128-
.execute()
129-
).flatMap { Mono.from(it.rowsUpdated) }
130-
}, Connection::close).awaitFirstOrNull()
131-
}
132-
}
120+
val worlds = fetchWorldsConcurrently(dbConnFactory, queries)
121+
val updatedWorlds = worlds.map {
122+
it.copy(randomNumber = ThreadLocalRandom.current().nextInt(1, DB_ROWS + 1))
123+
}.sortedBy { it.id }
133124

134-
call.respondJson(worldsUpdated)
125+
Mono.usingWhen(dbConnFactory.create(), { connection ->
126+
connection.beginTransaction()
127+
val statement = connection.createStatement(UPDATE_QUERY)
128+
updatedWorlds.forEach { world ->
129+
statement.bind(0, world.randomNumber).bind(1, world.id).add()
130+
}
131+
Mono.from(statement.execute())
132+
.flatMap { Mono.from(it.rowsUpdated) }
133+
.then(Mono.from(connection.commitTransaction()))
134+
},
135+
Connection::close,
136+
{ connection, _ -> connection.rollbackTransaction() },
137+
{ connection -> connection.rollbackTransaction() }
138+
).awaitSingle()
139+
140+
call.respondJson(updatedWorlds)
135141
}
136142
}
137143
}
138144

139145
private fun getWorld(
140-
dbConnFactory: ConnectionFactory, random: Random
146+
dbConnFactory: ConnectionFactory, random: ThreadLocalRandom = ThreadLocalRandom.current()
141147
): Mono<World> = Mono.usingWhen(dbConnFactory.create(), { connection ->
142148
Mono.from(connection.createStatement(WORLD_QUERY)
143149
.bind("$1", random.nextInt(DB_ROWS) + 1)
144150
.execute())
145-
.flatMap { r ->
146-
Mono.from(r.map { row, _ ->
147-
val id = row.get(0, Int::class.java)
148-
val randomNumber = row.get(1, Int::class.java)
149-
if (id != null && randomNumber != null) {
150-
World(id, randomNumber)
151-
} else {
152-
throw IllegalStateException("Database returned null values for required fields")
153-
}
151+
.flatMap { result ->
152+
Mono.from(result.map { row, _ ->
153+
World(
154+
row.get(0, Int::class.java)
155+
?: error("id is null"),
156+
row.get(1, Int::class.java)
157+
?: error("randomNumber is null")
158+
)
154159
})
155160
}
156161
}, Connection::close)
157162

163+
suspend fun fetchWorldsConcurrently(factory: ConnectionFactory, count: Int): List<World> =
164+
coroutineScope {
165+
(0 until count).map {
166+
async { getWorld(factory, ThreadLocalRandom.current()).awaitSingle() }
167+
}.awaitAll()
168+
}
169+
158170
private fun configurePostgresR2DBC(config: ApplicationConfig): ConnectionFactory {
159171
val cfo = PostgresqlConnectionConfiguration.builder()
160172
.host(config.property("db.host").getString())

0 commit comments

Comments
 (0)