Skip to content

Commit 951054c

Browse files
committed
Refactor StatementCache and message flow
StatementCache now no longer prepares statements (Parse) itself but rather reports whether statement preparation is required. That allows to streamline message handling to combine multiple messages into a composite one to reduce the number of sent TCP packets. ExtendedFlowDelegate encapsulates the extended flow. [fixes #341][#373]
1 parent c85e34a commit 951054c

18 files changed

+729
-816
lines changed

src/main/java/io/r2dbc/postgresql/BoundedStatementCache.java

Lines changed: 43 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import io.r2dbc.postgresql.client.Client;
2121
import io.r2dbc.postgresql.client.ExtendedQueryMessageFlow;
2222
import io.r2dbc.postgresql.util.Assert;
23-
import reactor.core.publisher.Mono;
23+
import reactor.util.Logger;
24+
import reactor.util.Loggers;
2425
import reactor.util.annotation.Nullable;
2526

2627
import java.util.ArrayList;
@@ -37,6 +38,8 @@
3738
*/
3839
final class BoundedStatementCache implements StatementCache {
3940

41+
private static final Logger LOGGER = Loggers.getLogger(BoundedStatementCache.class);
42+
4043
private final Map<CacheKey, String> cache = new LinkedHashMap<>(16, 0.75f, true);
4144

4245
private final Client client;
@@ -54,29 +57,47 @@ public BoundedStatementCache(Client client, int limit) {
5457
}
5558

5659
@Override
57-
public Mono<String> getName(Binding binding, String sql) {
60+
public String getName(Binding binding, String sql) {
5861
Assert.requireNonNull(binding, "binding must not be null");
5962
Assert.requireNonNull(sql, "sql must not be null");
60-
CacheKey key = new CacheKey(sql, binding.getParameterTypes());
61-
String name = get(key);
63+
64+
String name = get(new CacheKey(sql, binding.getParameterTypes()));
65+
6266
if (name != null) {
63-
return Mono.just(name);
67+
return name;
6468
}
6569

66-
Mono<Void> closeLastStatement = Mono.defer(() -> {
67-
if (getCacheSize() < this.limit) {
68-
return Mono.empty();
69-
}
70-
String lastAccessedStatementName = getAndRemoveEldest();
71-
ExceptionFactory factory = ExceptionFactory.withSql(lastAccessedStatementName);
72-
return ExtendedQueryMessageFlow
73-
.closeStatement(this.client, lastAccessedStatementName)
74-
.handle(factory::handleErrorResponse)
75-
.then();
76-
});
77-
78-
return closeLastStatement.then(parse(sql, binding.getParameterTypes()))
79-
.doOnNext(preparedName -> put(key, preparedName));
70+
return "S_" + this.counter.getAndIncrement();
71+
}
72+
73+
@Override
74+
public boolean requiresPrepare(Binding binding, String sql) {
75+
76+
Assert.requireNonNull(binding, "binding must not be null");
77+
Assert.requireNonNull(sql, "sql must not be null");
78+
79+
return get(new CacheKey(sql, binding.getParameterTypes())) == null;
80+
}
81+
82+
@Override
83+
public void put(Binding binding, String sql, String name) {
84+
85+
CacheKey key = new CacheKey(sql, binding.getParameterTypes());
86+
87+
put(key, name);
88+
89+
if (getCacheSize() <= this.limit) {
90+
return;
91+
}
92+
93+
Map.Entry<CacheKey, String> lastAccessedStatement = getAndRemoveEldest();
94+
ExceptionFactory factory = ExceptionFactory.withSql(lastAccessedStatement.getKey().sql);
95+
96+
ExtendedQueryMessageFlow
97+
.closeStatement(this.client, lastAccessedStatement.getValue())
98+
.handle(factory::handleErrorResponse)
99+
.subscribe(it -> {
100+
}, err -> LOGGER.warn(String.format("Cannot close statement %s (%s)", lastAccessedStatement.getValue(), lastAccessedStatement.getKey().sql), err));
80101
}
81102

82103
/**
@@ -87,7 +108,7 @@ public Mono<String> getName(Binding binding, String sql) {
87108
Collection<String> getCachedStatementNames() {
88109
synchronized (this.cache) {
89110
List<String> names = new ArrayList<>(this.cache.size());
90-
names.addAll(cache.values());
111+
names.addAll(this.cache.values());
91112
return names;
92113
}
93114
}
@@ -110,10 +131,10 @@ private String get(CacheKey key) {
110131
*
111132
* @return least recently used entry
112133
*/
113-
private String getAndRemoveEldest() {
134+
private Map.Entry<CacheKey, String> getAndRemoveEldest() {
114135
synchronized (this.cache) {
115136
Iterator<Map.Entry<CacheKey, String>> iterator = this.cache.entrySet().iterator();
116-
String entry = iterator.next().getValue();
137+
Map.Entry<CacheKey, String> entry = iterator.next();
117138
iterator.remove();
118139
return entry;
119140
}
@@ -149,17 +170,6 @@ public String toString() {
149170
'}';
150171
}
151172

152-
private Mono<String> parse(String sql, int[] types) {
153-
String name = "S_" + this.counter.getAndIncrement();
154-
155-
ExceptionFactory factory = ExceptionFactory.withSql(name);
156-
return ExtendedQueryMessageFlow
157-
.parse(this.client, name, sql, types)
158-
.handle(factory::handleErrorResponse)
159-
.then(Mono.just(name))
160-
.cache();
161-
}
162-
163173
static class CacheKey {
164174

165175
String sql;

src/main/java/io/r2dbc/postgresql/DisabledStatementCache.java

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,39 +17,31 @@
1717
package io.r2dbc.postgresql;
1818

1919
import io.r2dbc.postgresql.client.Binding;
20-
import io.r2dbc.postgresql.client.Client;
21-
import io.r2dbc.postgresql.client.ExtendedQueryMessageFlow;
22-
import io.r2dbc.postgresql.util.Assert;
23-
import reactor.core.publisher.Mono;
2420

2521
class DisabledStatementCache implements StatementCache {
2622

2723
private static final String UNNAMED_STATEMENT_NAME = "";
2824

29-
private final Client client;
25+
DisabledStatementCache() {
26+
}
3027

31-
DisabledStatementCache(Client client) {
32-
this.client = Assert.requireNonNull(client, "client must not be null");
28+
@Override
29+
public String getName(Binding binding, String sql) {
30+
return UNNAMED_STATEMENT_NAME;
31+
}
32+
33+
@Override
34+
public boolean requiresPrepare(Binding binding, String sql) {
35+
return true;
3336
}
3437

3538
@Override
36-
public Mono<String> getName(Binding binding, String sql) {
37-
Assert.requireNonNull(binding, "binding must not be null");
38-
Assert.requireNonNull(sql, "sql must not be null");
39-
String name = UNNAMED_STATEMENT_NAME;
40-
41-
ExceptionFactory factory = ExceptionFactory.withSql(name);
42-
return ExtendedQueryMessageFlow
43-
.parse(this.client, name, sql, binding.getParameterTypes())
44-
.handle(factory::handleErrorResponse)
45-
.then(Mono.just(name));
39+
public void put(Binding binding, String sql, String name) {
4640
}
4741

4842
@Override
4943
public String toString() {
50-
return "DisabledStatementCache{" +
51-
"client=" + this.client +
52-
'}';
44+
return "DisabledStatementCache";
5345
}
5446

5547
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
/*
2+
* Copyright 2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql;
18+
19+
import io.netty.buffer.ByteBuf;
20+
import io.netty.util.ReferenceCountUtil;
21+
import io.netty.util.ReferenceCounted;
22+
import io.r2dbc.postgresql.client.Binding;
23+
import io.r2dbc.postgresql.client.Client;
24+
import io.r2dbc.postgresql.client.ExtendedQueryMessageFlow;
25+
import io.r2dbc.postgresql.message.backend.BackendMessage;
26+
import io.r2dbc.postgresql.message.backend.BindComplete;
27+
import io.r2dbc.postgresql.message.backend.CommandComplete;
28+
import io.r2dbc.postgresql.message.backend.ErrorResponse;
29+
import io.r2dbc.postgresql.message.backend.NoData;
30+
import io.r2dbc.postgresql.message.backend.ParseComplete;
31+
import io.r2dbc.postgresql.message.backend.PortalSuspended;
32+
import io.r2dbc.postgresql.message.frontend.Bind;
33+
import io.r2dbc.postgresql.message.frontend.Close;
34+
import io.r2dbc.postgresql.message.frontend.CompositeFrontendMessage;
35+
import io.r2dbc.postgresql.message.frontend.Describe;
36+
import io.r2dbc.postgresql.message.frontend.Execute;
37+
import io.r2dbc.postgresql.message.frontend.Flush;
38+
import io.r2dbc.postgresql.message.frontend.FrontendMessage;
39+
import io.r2dbc.postgresql.message.frontend.Parse;
40+
import io.r2dbc.postgresql.message.frontend.Sync;
41+
import io.r2dbc.postgresql.util.Operators;
42+
import reactor.core.publisher.DirectProcessor;
43+
import reactor.core.publisher.Flux;
44+
import reactor.core.publisher.FluxSink;
45+
import reactor.core.publisher.Mono;
46+
import reactor.core.publisher.SynchronousSink;
47+
48+
import java.util.ArrayList;
49+
import java.util.List;
50+
import java.util.concurrent.atomic.AtomicBoolean;
51+
import java.util.function.Predicate;
52+
53+
import static io.r2dbc.postgresql.message.frontend.Execute.NO_LIMIT;
54+
import static io.r2dbc.postgresql.message.frontend.ExecutionType.PORTAL;
55+
import static io.r2dbc.postgresql.util.PredicateUtils.not;
56+
import static io.r2dbc.postgresql.util.PredicateUtils.or;
57+
58+
/**
59+
* Utility to execute the {@code Parse/Bind/Describe/Execute/Sync} portion of the <a href="https://www.postgresql.org/docs/current/static/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY">Extended query</a>
60+
* message flow.
61+
*/
62+
class ExtendedFlowDelegate {
63+
64+
static final Predicate<BackendMessage> RESULT_FRAME_FILTER = not(or(BindComplete.class::isInstance, NoData.class::isInstance));
65+
66+
/**
67+
* Execute the {@code Parse/Bind/Describe/Execute/Sync} portion of the <a href="https://www.postgresql.org/docs/current/static/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY">Extended query</a>
68+
* message flow.
69+
*
70+
* @param resources the {@link ConnectionResources} providing access to the {@link Client}
71+
* @param factory the {@link ExceptionFactory}
72+
* @param query the query to execute
73+
* @param binding the {@link Binding} to bind
74+
* @param values the binding values
75+
* @param fetchSize the fetch size to apply. Use a single {@link Execute} with fetch all if {@code fetchSize} is zero. Otherwise, perform multiple roundtrips with smaller
76+
* {@link Execute} sizes.
77+
* @return the messages received in response to the exchange
78+
* @throws IllegalArgumentException if {@code bindings}, {@code client}, {@code portalNameSupplier}, or {@code statementName} is {@code null}
79+
*/
80+
public static Flux<BackendMessage> runQuery(ConnectionResources resources, ExceptionFactory factory, String query, Binding binding, List<ByteBuf> values, int fetchSize) {
81+
82+
StatementCache cache = resources.getStatementCache();
83+
Client client = resources.getClient();
84+
85+
String name = cache.getName(binding, query);
86+
String portal = resources.getPortalNameSupplier().get();
87+
boolean prepareRequired = cache.requiresPrepare(binding, query);
88+
89+
List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
90+
91+
if (prepareRequired) {
92+
messagesToSend.add(new Parse(name, binding.getParameterTypes(), query));
93+
}
94+
95+
Bind bind = new Bind(portal, binding.getParameterFormats(), values, ExtendedQueryMessageFlow.resultFormat(resources.getConfiguration().isForceBinary()), name);
96+
97+
messagesToSend.add(bind);
98+
messagesToSend.add(new Describe(portal, PORTAL));
99+
100+
Flux<BackendMessage> exchange;
101+
102+
if (fetchSize == NO_LIMIT) {
103+
exchange = fetchAll(messagesToSend, client, portal);
104+
} else {
105+
exchange = fetchOptimisticCursored(messagesToSend, client, portal, fetchSize);
106+
}
107+
108+
if (prepareRequired) {
109+
110+
exchange = exchange.doOnNext(message -> {
111+
112+
if (message == ParseComplete.INSTANCE) {
113+
cache.put(binding, query, name);
114+
}
115+
});
116+
}
117+
118+
return exchange.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release).filter(RESULT_FRAME_FILTER).handle(factory::handleErrorResponse);
119+
}
120+
121+
/**
122+
* Execute the query and indicate to fetch all rows with the {@link Execute} message.
123+
*
124+
* @param messagesToSend the initial bind flow
125+
* @param client client to use
126+
* @param portal the portal
127+
* @return the resulting message stream
128+
*/
129+
private static Flux<BackendMessage> fetchAll(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal) {
130+
131+
messagesToSend.add(new Execute(portal, NO_LIMIT));
132+
messagesToSend.add(new Close(portal, PORTAL));
133+
messagesToSend.add(Sync.INSTANCE);
134+
135+
return client.exchange(Mono.just(new CompositeFrontendMessage(messagesToSend)))
136+
.as(Operators::discardOnCancel);
137+
}
138+
139+
/**
140+
* Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
141+
* cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
142+
*
143+
* @param messagesToSend the messages to send
144+
* @param client client to use
145+
* @param portal the portal
146+
* @param fetchSize fetch size per roundtrip
147+
* @return the resulting message stream
148+
*/
149+
private static Flux<BackendMessage> fetchOptimisticCursored(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
150+
151+
DirectProcessor<FrontendMessage> requestsProcessor = DirectProcessor.create();
152+
FluxSink<FrontendMessage> requestsSink = requestsProcessor.sink();
153+
AtomicBoolean isCanceled = new AtomicBoolean(false);
154+
155+
messagesToSend.add(new Execute(portal, fetchSize));
156+
messagesToSend.add(Flush.INSTANCE);
157+
158+
return client.exchange(Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requestsProcessor))
159+
.handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {
160+
161+
if (message instanceof CommandComplete) {
162+
requestsSink.next(new Close(portal, PORTAL));
163+
requestsSink.next(Sync.INSTANCE);
164+
requestsSink.complete();
165+
sink.next(message);
166+
} else if (message instanceof ErrorResponse) {
167+
requestsSink.next(Sync.INSTANCE);
168+
requestsSink.complete();
169+
sink.next(message);
170+
} else if (message instanceof PortalSuspended) {
171+
if (isCanceled.get()) {
172+
requestsSink.next(new Close(portal, PORTAL));
173+
requestsSink.next(Sync.INSTANCE);
174+
requestsSink.complete();
175+
} else {
176+
requestsSink.next(new Execute(portal, fetchSize));
177+
requestsSink.next(Flush.INSTANCE);
178+
}
179+
} else {
180+
sink.next(message);
181+
}
182+
}).doFinally(ignore -> requestsSink.complete())
183+
.as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
184+
}
185+
186+
}

0 commit comments

Comments
 (0)