Skip to content

Commit 42da082

Browse files
authored
fix: prepared statements with simple protocol not working (#628)
- fix: prepared statements with simple protocol not working #604
1 parent 7d58743 commit 42da082

File tree

13 files changed

+198
-43
lines changed

13 files changed

+198
-43
lines changed
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[general]
2+
prepared_statements = "full"
3+
4+
[[databases]]
5+
name = "pgdog"
6+
host = "127.0.0.1"
7+
8+
[admin]
9+
password = "pgdog"
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[[users]]
2+
database = "pgdog"
3+
name = "pgdog"
4+
password = "pgdog"
Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,33 @@
1-
from globals import normal_sync
2-
import pytest
1+
from globals import normal_sync, no_out_of_sync, admin
2+
from multiprocessing import Pool
3+
import time
34

45

5-
@pytest.mark.skip(reason="These are not working")
6-
def test_prepared_full():
7-
for _ in range(5):
8-
conn = normal_sync()
9-
conn.autocommit = True
10-
11-
cur = conn.cursor()
12-
cur.execute("PREPARE test_stmt AS SELECT 1")
13-
cur.execute("PREPARE test_stmt AS SELECT 2")
14-
6+
def run_prepare_execute(worker_id):
157
conn = normal_sync()
168
conn.autocommit = True
17-
for _ in range(5):
18-
cur = conn.cursor()
19-
20-
for i in range(5):
21-
cur.execute(f"PREPARE test_stmt_{i} AS SELECT $1::bigint")
22-
cur.execute(f"EXECUTE test_stmt_{i}({i})")
23-
result = cur.fetchone()
24-
assert result[0] == i
25-
conn.commit()
9+
cur = conn.cursor()
10+
11+
stmt_name = f"stmt_{worker_id}"
12+
cur.execute(f"PREPARE {stmt_name} AS SELECT $1::bigint * 2")
13+
time.sleep(0.01)
14+
15+
for i in range(100):
16+
cur.execute(f"EXECUTE {stmt_name}({i})")
17+
result = cur.fetchone()
18+
assert result[0] == i * 2
19+
time.sleep(0.01)
20+
21+
cur.execute(f"DEALLOCATE {stmt_name}")
22+
conn.close()
23+
return True
24+
25+
26+
def test_prepare_execute_parallel():
27+
admin().execute("SET prepared_statements TO 'full'")
28+
29+
with Pool(5) as pool:
30+
results = pool.map(run_prepare_execute, range(5))
31+
assert all(results)
32+
no_out_of_sync()
33+
admin().execute("RELOAD")

pgdog/src/backend/pool/guard.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -727,4 +727,53 @@ mod test {
727727
assert!(server.needs_drain());
728728
assert!(server.in_transaction());
729729
}
730+
731+
#[tokio::test]
732+
async fn test_cleanup_syncs_prepared_statements() {
733+
crate::logger();
734+
735+
let mut server = Guard::new(
736+
Pool::new_test(),
737+
Box::new(test_server().await),
738+
Instant::now(),
739+
);
740+
741+
assert!(!server.sync_prepared());
742+
743+
server
744+
.send(&vec![Query::new("PREPARE test_stmt AS SELECT $1::bigint").into()].into())
745+
.await
746+
.unwrap();
747+
748+
for c in ['C', 'Z'] {
749+
let msg = server.read().await.unwrap();
750+
assert_eq!(msg.code(), c);
751+
}
752+
753+
assert!(
754+
server.sync_prepared(),
755+
"sync_prepared flag should be set after PREPARE command"
756+
);
757+
758+
let mut guard = server;
759+
let mut server = guard.server.take().unwrap();
760+
let cleanup = Cleanup::new(&guard, &mut server);
761+
762+
Guard::cleanup_internal(&mut server, cleanup, ConnectionRecovery::Recover)
763+
.await
764+
.unwrap();
765+
766+
assert!(
767+
!server.sync_prepared(),
768+
"sync_prepared flag should be cleared after cleanup"
769+
);
770+
771+
assert!(
772+
server.prepared_statements_mut().contains("test_stmt"),
773+
"Statement should be in local cache after sync"
774+
);
775+
776+
let one: Vec<i32> = server.fetch_all("SELECT 1").await.unwrap();
777+
assert_eq!(one[0], 1);
778+
}
730779
}

pgdog/src/backend/prepared_statements.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,16 @@ impl PreparedStatements {
174174
self.state.add('3');
175175
}
176176
}
177-
ProtocolMessage::Prepare { .. } => (),
177+
ProtocolMessage::Prepare { name, .. } => {
178+
if self.contains(name) {
179+
return Ok(HandleResult::Drop);
180+
} else {
181+
self.parses.push_back(name.clone());
182+
self.state.add_ignore('C');
183+
self.state.add_ignore('Z');
184+
return Ok(HandleResult::Forward);
185+
}
186+
}
178187
ProtocolMessage::CopyDone(_) => {
179188
self.state.action('c')?;
180189
}
@@ -218,7 +227,7 @@ impl PreparedStatements {
218227
self.describes.pop_front();
219228
}
220229

221-
'1' => {
230+
'1' | 'C' => {
222231
if let Some(name) = self.parses.pop_front() {
223232
self.prepared(&name);
224233
}

pgdog/src/backend/server.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,7 @@ impl Server {
721721

722722
let count = self.prepared_statements.len();
723723
self.stats.set_prepared_statements(count);
724+
self.sync_prepared = false;
724725

725726
Ok(())
726727
}
@@ -2304,4 +2305,40 @@ pub mod test {
23042305
final_idle_time,
23052306
);
23062307
}
2308+
2309+
#[tokio::test]
2310+
async fn test_prepare_forces_sync_prepared_flag() {
2311+
let mut server = test_server().await;
2312+
2313+
assert!(!server.sync_prepared());
2314+
2315+
server
2316+
.send(&vec![Query::new("PREPARE test_stmt AS SELECT $1::bigint").into()].into())
2317+
.await
2318+
.unwrap();
2319+
2320+
for c in ['C', 'Z'] {
2321+
let msg = server.read().await.unwrap();
2322+
assert_eq!(msg.code(), c);
2323+
}
2324+
2325+
assert!(
2326+
server.sync_prepared(),
2327+
"sync_prepared flag should be set after PREPARE command"
2328+
);
2329+
2330+
server.sync_prepared_statements().await.unwrap();
2331+
2332+
assert!(
2333+
!server.sync_prepared(),
2334+
"sync_prepared flag should be cleared after sync_prepared_statements()"
2335+
);
2336+
2337+
server.execute("SELECT 1").await.unwrap();
2338+
2339+
assert!(
2340+
!server.sync_prepared(),
2341+
"sync_prepared flag should remain false after regular queries"
2342+
);
2343+
}
23072344
}

pgdog/src/frontend/client/query_engine/query.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ impl QueryEngine {
4848
}
4949

5050
if let Some(sql) = route.rewritten_sql() {
51-
match context.client_request.rewrite(sql) {
51+
match context.client_request.rewrite(&[Query::new(sql).into()]) {
5252
Ok(()) => (),
5353
Err(crate::net::Error::OnlySimpleForRewrites) => {
5454
context.client_request.rewrite_prepared(

pgdog/src/frontend/client_request.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use regex::Regex;
77
use crate::{
88
frontend::router::parser::RewritePlan,
99
net::{
10-
messages::{Bind, CopyData, Protocol, Query},
10+
messages::{Bind, CopyData, Protocol},
1111
Error, Flush, ProtocolMessage,
1212
},
1313
stats::memory::MemoryUsage,
@@ -180,12 +180,12 @@ impl ClientRequest {
180180
}
181181

182182
/// Rewrite query in buffer.
183-
pub fn rewrite(&mut self, query: &str) -> Result<(), Error> {
183+
pub fn rewrite(&mut self, request: &[ProtocolMessage]) -> Result<(), Error> {
184184
if self.messages.iter().any(|c| c.code() != 'Q') {
185185
return Err(Error::OnlySimpleForRewrites);
186186
}
187187
self.messages.clear();
188-
self.messages.push(Query::new(query).into());
188+
self.messages.extend(request.to_vec());
189189
Ok(())
190190
}
191191

@@ -316,7 +316,7 @@ impl DerefMut for ClientRequest {
316316

317317
#[cfg(test)]
318318
mod test {
319-
use crate::net::{Describe, Execute, Parse, Sync};
319+
use crate::net::{Describe, Execute, Parse, Query, Sync};
320320

321321
use super::*;
322322

pgdog/src/frontend/prepared_statements/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ impl PreparedStatements {
102102
self.global.read().rewrite_plan(name)
103103
}
104104

105+
/// Set rewrite plan for UPDATE / INSERT statement
106+
/// rewrites used in rewritten cross-shard queries.
105107
pub fn update_and_set_rewrite_plan(
106108
&mut self,
107109
name: &str,
@@ -117,6 +119,14 @@ impl PreparedStatements {
117119
self.local.get(name)
118120
}
119121

122+
/// Get globally-prepared statement by local name.
123+
pub fn parse(&self, name: &str) -> Option<Parse> {
124+
self.local
125+
.get(name)
126+
.map(|name| self.global.read().parse(name))
127+
.flatten()
128+
}
129+
120130
/// Number of prepared statements in the local cache.
121131
pub fn len_local(&self) -> usize {
122132
self.local.len()

pgdog/src/frontend/router/parser/command.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use super::*;
22
use crate::{
33
frontend::{client::TransactionType, BufferedQuery},
4-
net::parameter::ParameterValue,
4+
net::{parameter::ParameterValue, ProtocolMessage},
55
};
66
use lazy_static::lazy_static;
77

@@ -28,7 +28,7 @@ pub enum Command {
2828
value: ParameterValue,
2929
},
3030
PreparedStatement(Prepare),
31-
Rewrite(String),
31+
Rewrite(Vec<ProtocolMessage>),
3232
Shards(usize),
3333
Deallocate,
3434
Discard {

0 commit comments

Comments
 (0)