1919import io .netty .buffer .ByteBuf ;
2020import io .netty .util .ReferenceCountUtil ;
2121import io .netty .util .ReferenceCounted ;
22+ import io .r2dbc .postgresql .api .ErrorDetails ;
2223import io .r2dbc .postgresql .client .Binding ;
2324import io .r2dbc .postgresql .client .Client ;
2425import io .r2dbc .postgresql .client .ExtendedQueryMessageFlow ;
4546import io .r2dbc .postgresql .util .Operators ;
4647import reactor .core .publisher .Flux ;
4748import reactor .core .publisher .FluxSink ;
48- import reactor .core .publisher .Mono ;
4949import reactor .core .publisher .SynchronousSink ;
5050import reactor .core .publisher .UnicastProcessor ;
51+ import reactor .util .annotation .Nullable ;
5152import reactor .util .concurrent .Queues ;
5253
5354import java .util .ArrayList ;
55+ import java .util .Arrays ;
56+ import java .util .Collection ;
5457import java .util .List ;
5558import java .util .concurrent .atomic .AtomicBoolean ;
59+ import java .util .concurrent .atomic .AtomicInteger ;
60+ import java .util .function .BiConsumer ;
5661import java .util .function .Predicate ;
5762
5863import static io .r2dbc .postgresql .message .frontend .Execute .NO_LIMIT ;
@@ -87,92 +92,81 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
8792 StatementCache cache = resources .getStatementCache ();
8893 Client client = resources .getClient ();
8994
90- String name = cache .getName (binding , query );
9195 String portal = resources .getPortalNameSupplier ().get ();
92- boolean prepareRequired = cache .requiresPrepare (binding , query );
93-
94- List <FrontendMessage .DirectEncoder > messagesToSend = new ArrayList <>(6 );
95-
96- if (prepareRequired ) {
97- messagesToSend .add (new Parse (name , binding .getParameterTypes (), query ));
98- }
99-
100- Bind bind = new Bind (portal , binding .getParameterFormats (), values , ExtendedQueryMessageFlow .resultFormat (resources .getConfiguration ().isForceBinary ()), name );
101-
102- messagesToSend .add (bind );
103- messagesToSend .add (new Describe (portal , PORTAL ));
10496
10597 Flux <BackendMessage > exchange ;
10698 boolean compatibilityMode = resources .getConfiguration ().isCompatibilityMode ();
10799 boolean implicitTransactions = resources .getClient ().getTransactionStatus () == TransactionStatus .IDLE ;
108100
101+ ExtendedFlowOperator operator = new ExtendedFlowOperator (query , binding , cache , values , portal , resources .getConfiguration ().isForceBinary ());
102+
109103 if (compatibilityMode ) {
110104
111105 if (fetchSize == NO_LIMIT || implicitTransactions ) {
112- exchange = fetchAll (messagesToSend , client , portal );
106+ exchange = fetchAll (operator , client , portal );
113107 } else {
114- exchange = fetchCursoredWithSync (messagesToSend , client , portal , fetchSize );
108+ exchange = fetchCursoredWithSync (operator , client , portal , fetchSize );
115109 }
116110 } else {
117111
118112 if (fetchSize == NO_LIMIT ) {
119- exchange = fetchAll (messagesToSend , client , portal );
113+ exchange = fetchAll (operator , client , portal );
120114 } else {
121- exchange = fetchCursoredWithFlush (messagesToSend , client , portal , fetchSize );
115+ exchange = fetchCursoredWithFlush (operator , client , portal , fetchSize );
122116 }
123117 }
124118
125- if (prepareRequired ) {
126-
127- exchange = exchange .doOnNext (message -> {
119+ exchange = exchange .doOnNext (message -> {
128120
129- if (message == ParseComplete .INSTANCE ) {
130- cache .put (binding , query , name );
131- }
132- });
133- }
121+ if (message == ParseComplete .INSTANCE ) {
122+ operator .hydrateStatementCache ();
123+ }
124+ });
134125
135126 return exchange .doOnSubscribe (it -> QueryLogger .logQuery (client .getContext (), query )).doOnDiscard (ReferenceCounted .class , ReferenceCountUtil ::release ).filter (RESULT_FRAME_FILTER ).handle (factory ::handleErrorResponse );
136127 }
137128
138129 /**
139130 * Execute the query and indicate to fetch all rows with the {@link Execute} message.
140131 *
141- * @param messagesToSend the initial bind flow
142- * @param client client to use
143- * @param portal the portal
132+ * @param operator the flow operator
133+ * @param client client to use
134+ * @param portal the portal
144135 * @return the resulting message stream
145136 */
146- private static Flux <BackendMessage > fetchAll (List < FrontendMessage . DirectEncoder > messagesToSend , Client client , String portal ) {
137+ private static Flux <BackendMessage > fetchAll (ExtendedFlowOperator operator , Client client , String portal ) {
147138
148- messagesToSend . add ( new Execute ( portal , NO_LIMIT ));
149- messagesToSend . add ( new Close ( portal , PORTAL ) );
150- messagesToSend . add ( Sync .INSTANCE );
139+ UnicastProcessor < FrontendMessage > requestsProcessor = UnicastProcessor . create ( Queues .< FrontendMessage > small (). get ( ));
140+ FluxSink < FrontendMessage > requestsSink = requestsProcessor . sink ( );
141+ MessageFactory factory = () -> operator . getMessages ( Arrays . asList ( new Execute ( portal , NO_LIMIT ), new Close ( portal , PORTAL ), Sync .INSTANCE ) );
151142
152- return client .exchange (Mono .just (new CompositeFrontendMessage (messagesToSend )))
143+ return client .exchange (operator .takeUntil (), Flux .<FrontendMessage >just (new CompositeFrontendMessage (factory .createMessages ())).concatWith (requestsProcessor ))
144+ .handle (handleReprepare (requestsSink , operator , factory ))
145+ .doFinally (ignore -> operator .close (requestsSink ))
153146 .as (Operators ::discardOnCancel );
154147 }
155148
156149 /**
157150 * Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
158151 *
159- * @param messagesToSend the messages to send
160- * @param client client to use
161- * @param portal the portal
162- * @param fetchSize fetch size per roundtrip
152+ * @param operator the flow operator
153+ * @param client client to use
154+ * @param portal the portal
155+ * @param fetchSize fetch size per roundtrip
163156 * @return the resulting message stream
164157 */
165- private static Flux <BackendMessage > fetchCursoredWithSync (List < FrontendMessage . DirectEncoder > messagesToSend , Client client , String portal , int fetchSize ) {
158+ private static Flux <BackendMessage > fetchCursoredWithSync (ExtendedFlowOperator operator , Client client , String portal , int fetchSize ) {
166159
167160 UnicastProcessor <FrontendMessage > requestsProcessor = UnicastProcessor .create (Queues .<FrontendMessage >small ().get ());
168161 FluxSink <FrontendMessage > requestsSink = requestsProcessor .sink ();
169162 AtomicBoolean isCanceled = new AtomicBoolean (false );
170163 AtomicBoolean done = new AtomicBoolean (false );
171164
172- messagesToSend . add ( new Execute (portal , fetchSize ));
173- messagesToSend . add ( Sync . INSTANCE );
165+ MessageFactory factory = () -> operator . getMessages ( Arrays . asList ( new Execute (portal , fetchSize ), Sync . INSTANCE ));
166+ Predicate < BackendMessage > takeUntil = operator . takeUntil ( );
174167
175- return client .exchange (it -> done .get () && it instanceof ReadyForQuery , Flux .<FrontendMessage >just (new CompositeFrontendMessage (messagesToSend )).concatWith (requestsProcessor ))
168+ return client .exchange (it -> done .get () && takeUntil .test (it ), Flux .<FrontendMessage >just (new CompositeFrontendMessage (factory .createMessages ())).concatWith (requestsProcessor ))
169+ .handle (handleReprepare (requestsSink , operator , factory ))
176170 .handle ((BackendMessage message , SynchronousSink <BackendMessage > sink ) -> {
177171
178172 if (message instanceof CommandComplete ) {
@@ -211,30 +205,30 @@ private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.D
211205 } else {
212206 sink .next (message );
213207 }
214- }).doFinally (ignore -> requestsSink . complete ( ))
208+ }).doFinally (ignore -> operator . close ( requestsSink ))
215209 .as (flux -> Operators .discardOnCancel (flux , () -> isCanceled .set (true )));
216210 }
217211
218212 /**
219213 * Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
220214 * cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
221215 *
222- * @param messagesToSend the messages to send
223- * @param client client to use
224- * @param portal the portal
225- * @param fetchSize fetch size per roundtrip
216+ * @param operator the flow operator
217+ * @param client client to use
218+ * @param portal the portal
219+ * @param fetchSize fetch size per roundtrip
226220 * @return the resulting message stream
227221 */
228- private static Flux <BackendMessage > fetchCursoredWithFlush (List < FrontendMessage . DirectEncoder > messagesToSend , Client client , String portal , int fetchSize ) {
222+ private static Flux <BackendMessage > fetchCursoredWithFlush (ExtendedFlowOperator operator , Client client , String portal , int fetchSize ) {
229223
230224 UnicastProcessor <FrontendMessage > requestsProcessor = UnicastProcessor .create (Queues .<FrontendMessage >small ().get ());
231225 FluxSink <FrontendMessage > requestsSink = requestsProcessor .sink ();
232226 AtomicBoolean isCanceled = new AtomicBoolean (false );
233227
234- messagesToSend .add (new Execute (portal , fetchSize ));
235- messagesToSend .add (Flush .INSTANCE );
228+ MessageFactory factory = () -> operator .getMessages (Arrays .asList (new Execute (portal , fetchSize ), Flush .INSTANCE ));
236229
237- return client .exchange (Flux .<FrontendMessage >just (new CompositeFrontendMessage (messagesToSend )).concatWith (requestsProcessor ))
230+ return client .exchange (operator .takeUntil (), Flux .<FrontendMessage >just (new CompositeFrontendMessage (factory .createMessages ())).concatWith (requestsProcessor ))
231+ .handle (handleReprepare (requestsSink , operator , factory ))
238232 .handle ((BackendMessage message , SynchronousSink <BackendMessage > sink ) -> {
239233
240234 if (message instanceof CommandComplete ) {
@@ -258,8 +252,154 @@ private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.
258252 } else {
259253 sink .next (message );
260254 }
261- }).doFinally (ignore -> requestsSink . complete ( ))
255+ }).doFinally (ignore -> operator . close ( requestsSink ))
262256 .as (flux -> Operators .discardOnCancel (flux , () -> isCanceled .set (true )));
263257 }
264258
259+ private static BiConsumer <BackendMessage , SynchronousSink <BackendMessage >> handleReprepare (FluxSink <FrontendMessage > requests , ExtendedFlowOperator operator , MessageFactory messageFactory ) {
260+
261+ AtomicBoolean reprepared = new AtomicBoolean ();
262+
263+ return (message , sink ) -> {
264+
265+ if (message instanceof ErrorResponse && requiresReprepare ((ErrorResponse ) message ) && reprepared .compareAndSet (false , true )) {
266+
267+ operator .evictCachedStatement ();
268+
269+ List <FrontendMessage .DirectEncoder > messages = messageFactory .createMessages ();
270+ if (!messages .contains (Sync .INSTANCE )) {
271+ messages .add (0 , Sync .INSTANCE );
272+ }
273+ requests .next (new CompositeFrontendMessage (messages ));
274+ } else {
275+ sink .next (message );
276+ }
277+ };
278+ }
279+
280+ private static boolean requiresReprepare (ErrorResponse errorResponse ) {
281+
282+ ErrorDetails details = new ErrorDetails (errorResponse .getFields ());
283+ String code = details .getCode ();
284+
285+ // "prepared statement \"S_2\" does not exist"
286+ // INVALID_SQL_STATEMENT_NAME
287+ if ("26000" .equals (code )) {
288+ return true ;
289+ }
290+ // NOT_IMPLEMENTED
291+
292+ if (!"0A000" .equals (code )) {
293+ return false ;
294+ }
295+
296+ String routine = details .getRoutine ().orElse (null );
297+ // "cached plan must not change result type"
298+ return "RevalidateCachedQuery" .equals (routine ) // 9.2+
299+ || "RevalidateCachedPlan" .equals (routine ); // <= 9.1
300+ }
301+
302+ interface MessageFactory {
303+
304+ List <FrontendMessage .DirectEncoder > createMessages ();
305+
306+ }
307+
308+ /**
309+ * Operator to encapsulate common activity around the extended flow. Subclasses {@link AtomicInteger} to capture the number of ReadyForQuery frames.
310+ */
311+ static class ExtendedFlowOperator extends AtomicInteger {
312+
313+ private final String sql ;
314+
315+ private final Binding binding ;
316+
317+ @ Nullable
318+ private volatile String name ;
319+
320+ private final StatementCache cache ;
321+
322+ private final List <ByteBuf > values ;
323+
324+ private final String portal ;
325+
326+ private final boolean forceBinary ;
327+
328+ public ExtendedFlowOperator (String sql , Binding binding , StatementCache cache , List <ByteBuf > values , String portal , boolean forceBinary ) {
329+ this .sql = sql ;
330+ this .binding = binding ;
331+ this .cache = cache ;
332+ this .values = values ;
333+ this .portal = portal ;
334+ this .forceBinary = forceBinary ;
335+ set (1 );
336+ }
337+
338+ public void close (FluxSink <FrontendMessage > requests ) {
339+ requests .complete ();
340+ this .values .forEach (ReferenceCountUtil ::release );
341+ }
342+
343+ public void evictCachedStatement () {
344+
345+ incrementAndGet ();
346+
347+ synchronized (this ) {
348+ this .name = null ;
349+ }
350+ this .cache .evict (this .sql );
351+ }
352+
353+ public void hydrateStatementCache () {
354+ this .cache .put (this .binding , this .sql , getStatementName ());
355+ }
356+
357+ public Predicate <BackendMessage > takeUntil () {
358+ return m -> {
359+
360+ if (m instanceof ReadyForQuery ) {
361+ return decrementAndGet () <= 0 ;
362+ }
363+
364+ return false ;
365+ };
366+ }
367+
368+ private boolean isPrepareRequired () {
369+ return this .cache .requiresPrepare (this .binding , this .sql );
370+ }
371+
372+ public String getStatementName () {
373+ synchronized (this ) {
374+
375+ if (this .name == null ) {
376+ this .name = this .cache .getName (this .binding , this .sql );
377+ }
378+ return this .name ;
379+ }
380+ }
381+
382+ public List <FrontendMessage .DirectEncoder > getMessages (Collection <FrontendMessage .DirectEncoder > append ) {
383+ List <FrontendMessage .DirectEncoder > messagesToSend = new ArrayList <>(6 );
384+
385+ if (isPrepareRequired ()) {
386+ messagesToSend .add (new Parse (getStatementName (), this .binding .getParameterTypes (), this .sql ));
387+ }
388+
389+ for (ByteBuf value : this .values ) {
390+ value .readerIndex (0 );
391+ value .touch ("ExtendedFlowOperator" ).retain ();
392+ }
393+
394+ Bind bind = new Bind (this .portal , this .binding .getParameterFormats (), this .values , ExtendedQueryMessageFlow .resultFormat (this .forceBinary ), getStatementName ());
395+
396+ messagesToSend .add (bind );
397+ messagesToSend .add (new Describe (this .portal , PORTAL ));
398+ messagesToSend .addAll (append );
399+
400+ return messagesToSend ;
401+ }
402+
403+ }
404+
265405}
0 commit comments