diff --git a/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/AbstractR2DBCQuery.java b/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/AbstractR2DBCQuery.java index 9210e7045..257bbb07d 100644 --- a/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/AbstractR2DBCQuery.java +++ b/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/AbstractR2DBCQuery.java @@ -40,10 +40,12 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -212,31 +214,32 @@ protected Configuration getConfiguration() { @SuppressWarnings("unchecked") @Override public Flux fetch() { - return getConnection() - .flatMapMany( - conn -> { - var expr = (Expression) queryMixin.getMetadata().getProjection(); - var serializer = serialize(false); - var mapper = createMapper(expr); - - var constants = serializer.getConstants(); - var originalSql = serializer.toString(); - var sql = - R2dbcUtils.replaceBindingArguments( - configuration.getBindMarkerFactory().create(), constants, originalSql); - - var statement = conn.createStatement(sql); - BindTarget bindTarget = new StatementWrapper(statement); - - setParameters( - bindTarget, - configuration.getBindMarkerFactory().create(), - constants, - serializer.getConstantPaths(), - getMetadata().getParams()); - - return Flux.from(statement.execute()).flatMap(result -> result.map(mapper::map)); - }); + Function> work = + connection -> { + var expr = (Expression) queryMixin.getMetadata().getProjection(); + var serializer = serialize(false); + var mapper = createMapper(expr); + + var constants = serializer.getConstants(); + var originalSql = serializer.toString(); + var sql = + R2dbcUtils.replaceBindingArguments( + configuration.getBindMarkerFactory().create(), constants, originalSql); + + var statement = connection.createStatement(sql); + BindTarget bindTarget = new StatementWrapper(statement); + + setParameters( + bindTarget, + configuration.getBindMarkerFactory().create(), + constants, + serializer.getConstantPaths(), + getMetadata().getParams()); + + return Flux.from(statement.execute()).flatMap(result -> result.map(mapper::map)); + }; + + return usingConnectionMany(work); } private Mapper createMapper(Expression expr) { @@ -322,32 +325,53 @@ private Mono unsafeCount() { logQuery(sql, constants); - return getConnection() - .flatMap( - connection -> { - var statement = getStatement(connection, sql); - BindTarget bindTarget = new StatementWrapper(statement); - - setParameters( - bindTarget, - configuration.getBindMarkerFactory().create(), - constants, - serializer.getConstantPaths(), - getMetadata().getParams()); - - return Mono.from(statement.execute()) - .flatMap(result -> Mono.from(result.map((row, rowMetadata) -> row.get(0)))) - .map( - o -> { - if (Integer.class.isAssignableFrom(o.getClass())) { - return ((Integer) o).longValue(); - } - - return (Long) o; - }) - .defaultIfEmpty(0L) - .doOnError(e -> Mono.error(configuration.translate(sql, constants, e))); - }); + Function> work = + connection -> { + var statement = getStatement(connection, sql); + BindTarget bindTarget = new StatementWrapper(statement); + + setParameters( + bindTarget, + configuration.getBindMarkerFactory().create(), + constants, + serializer.getConstantPaths(), + getMetadata().getParams()); + + return Mono.from(statement.execute()) + .flatMap(result -> Mono.from(result.map((row, rowMetadata) -> row.get(0)))) + .map( + o -> { + if (Integer.class.isAssignableFrom(o.getClass())) { + return ((Integer) o).longValue(); + } + + return (Long) o; + }) + .defaultIfEmpty(0L) + .doOnError(e -> Mono.error(configuration.translate(sql, constants, e))); + }; + + return usingConnection(work); + } + + private Flux usingConnectionMany(Function> callback) { + if (connProvider != null) { + return connProvider.withConnectionMany(callback); + } + if (conn != null) { + return Flux.defer(() -> Flux.from(callback.apply(conn))); + } + return Flux.error(new IllegalStateException("No connection provided")); + } + + private Mono usingConnection(Function> callback) { + if (connProvider != null) { + return connProvider.withConnection(callback); + } + if (conn != null) { + return Mono.defer(() -> callback.apply(conn)); + } + return Mono.error(new IllegalStateException("No connection provided")); } protected void logQuery(String queryString, Collection parameters) { diff --git a/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/R2DBCConnectionProvider.java b/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/R2DBCConnectionProvider.java index 1ee89b615..f8778784e 100644 --- a/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/R2DBCConnectionProvider.java +++ b/querydsl-libraries/querydsl-r2dbc/src/main/java/com/querydsl/r2dbc/R2DBCConnectionProvider.java @@ -1,6 +1,11 @@ package com.querydsl.r2dbc; import io.r2dbc.spi.Connection; +import io.r2dbc.spi.ConnectionFactory; +import java.util.Objects; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; /** R2DBC connection provider */ @@ -13,4 +18,73 @@ public interface R2DBCConnectionProvider { * @return the connection of the current transaction */ Mono getConnection(); + + /** + * Release the connection returned from {@link #getConnection()} once the consumer has finished + * using it. Default implementation is a no-op which is suitable when the provider exposes + * externally managed connections (for example, a transaction scoped connection). + * + * @param connection connection to release + * @return completion signal for the release + */ + default Mono release(Connection connection) { + return Mono.empty(); + } + + /** + * Execute the given callback with a managed connection and release it afterwards. + * + * @param callback work to perform with the managed connection + * @param result type + * @return mono emitting the callback result + */ + default Mono withConnection(Function> callback) { + Objects.requireNonNull(callback, "callback"); + return Mono.usingWhen( + getConnection(), + connection -> Mono.defer(() -> callback.apply(connection)), + this::release, + (connection, error) -> release(connection), + connection -> release(connection)); + } + + /** + * Execute the given callback that returns a {@link Publisher} sequence with a managed connection + * and release the connection afterwards. + * + * @param callback work to perform with the managed connection + * @param element type emitted by the publisher + * @return flux emitting the callback results + */ + default Flux withConnectionMany(Function> callback) { + Objects.requireNonNull(callback, "callback"); + return Flux.usingWhen( + getConnection(), + connection -> Flux.from(callback.apply(connection)), + this::release, + (connection, error) -> release(connection), + connection -> release(connection)); + } + + /** + * Create a {@link R2DBCConnectionProvider} backed by a {@link ConnectionFactory}. Each invocation + * creates a new connection from the factory and ensures it is closed after use. + * + * @param connectionFactory source of connections + * @return provider that creates and closes connections per use + */ + static R2DBCConnectionProvider from(ConnectionFactory connectionFactory) { + Objects.requireNonNull(connectionFactory, "connectionFactory"); + return new R2DBCConnectionProvider() { + @Override + public Mono getConnection() { + return Mono.from(connectionFactory.create()); + } + + @Override + public Mono release(Connection connection) { + return Mono.from(connection.close()); + } + }; + } } diff --git a/querydsl-libraries/querydsl-r2dbc/src/test/java/com/querydsl/r2dbc/Connections.java b/querydsl-libraries/querydsl-r2dbc/src/test/java/com/querydsl/r2dbc/Connections.java index b9b87b7cc..605102f23 100644 --- a/querydsl-libraries/querydsl-r2dbc/src/test/java/com/querydsl/r2dbc/Connections.java +++ b/querydsl-libraries/querydsl-r2dbc/src/test/java/com/querydsl/r2dbc/Connections.java @@ -74,7 +74,7 @@ public final class Connections { private static boolean sqlServerInited, h2Inited, mysqlInited, postgresqlInited; public static R2DBCConnectionProvider getR2DBCConnectionProvider(String url) { - return () -> Mono.from(getConnectionProvider(url).create()); + return R2DBCConnectionProvider.from(getConnectionProvider(url)); } public static ConnectionFactory getConnectionProvider(String url) {