Skip to content

Commit 57d9a7c

Browse files
jaggederestCopilotlevkk
authored
Rewrite bug fixes - old cache key, more efficient update (#516)
* Rewrite bug fixes - old cache key, more efficient update * Update pgdog/src/frontend/prepared_statements/mod.rs Co-authored-by: Copilot <[email protected]> * Update pgdog/src/frontend/client_request.rs Co-authored-by: Lev Kokotov <[email protected]> * Skip removing and readding cache item, just modify it in place * skip cloning name if passed by reference * skip cloning name if passed by reference * format * cache key fixes * Update cache removal test to handle close_unused(0) clearing --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: Lev Kokotov <[email protected]>
1 parent 375ad58 commit 57d9a7c

File tree

5 files changed

+246
-53
lines changed

5 files changed

+246
-53
lines changed

integration/rust/tests/integration/prepared.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,136 @@ async fn test_prepared_cache() {
7979
});
8080
}
8181

82+
#[tokio::test]
83+
async fn test_prepared_cache_respects_limit() {
84+
let admin = admin_sqlx().await;
85+
86+
// Start from a clean state so results aren't influenced by previous tests.
87+
admin.execute("RECONNECT").await.unwrap();
88+
admin
89+
.execute("SET prepared_statements_limit TO 100")
90+
.await
91+
.unwrap();
92+
93+
// Run the average helper query a few times to populate the cache.
94+
for _ in 0..3 {
95+
let pools = connections_sqlx().await;
96+
for pool in &pools {
97+
sqlx::query("/* test_prepared_cache_rust */ SELECT $1")
98+
.bind(5)
99+
.fetch_one(pool)
100+
.await
101+
.unwrap();
102+
}
103+
104+
for pool in pools {
105+
pool.close().await;
106+
}
107+
}
108+
109+
let mut prepared = admin.fetch_all("SHOW PREPARED").await.unwrap();
110+
prepared.retain(|row| {
111+
row.get::<String, &str>("statement")
112+
.contains("/* test_prepared_cache_rust")
113+
});
114+
assert_eq!(
115+
prepared.len(),
116+
2,
117+
"expected the original statement plus its helper"
118+
);
119+
120+
// Tighten the cache limit and ensure unused statements are evicted.
121+
admin
122+
.execute("SET prepared_statements_limit TO 1")
123+
.await
124+
.unwrap();
125+
126+
let mut prepared = admin.fetch_all("SHOW PREPARED").await.unwrap();
127+
prepared.retain(|row| {
128+
row.get::<String, &str>("statement")
129+
.contains("/* test_prepared_cache_rust")
130+
});
131+
assert!(
132+
prepared.len() <= 1,
133+
"expected helper statements to be evicted when limit drops"
134+
);
135+
136+
admin.execute("RELOAD").await.unwrap();
137+
}
138+
139+
#[tokio::test]
140+
async fn test_prepared_cache_helper_evicted_on_close() {
141+
let admin = admin_sqlx().await;
142+
admin.execute("RECONNECT").await.unwrap();
143+
admin
144+
.execute("SET prepared_statements_limit TO 100")
145+
.await
146+
.unwrap();
147+
148+
let mut pools = connections_sqlx().await;
149+
let sharded = pools.remove(1);
150+
let primary = pools.remove(0);
151+
152+
// Build a deterministic dataset per shard to trigger AVG rewrite.
153+
for shard in [0, 1] {
154+
let drop = format!(
155+
"/* pgdog_shard: {} */ DROP TABLE IF EXISTS avg_helper_cleanup",
156+
shard
157+
);
158+
sharded.execute(drop.as_str()).await.ok();
159+
160+
let create = format!(
161+
"/* pgdog_shard: {} */ CREATE TABLE avg_helper_cleanup(price DOUBLE PRECISION)",
162+
shard
163+
);
164+
sharded.execute(create.as_str()).await.unwrap();
165+
}
166+
167+
sharded
168+
.execute("/* pgdog_shard: 0 */ INSERT INTO avg_helper_cleanup(price) VALUES (10.0), (14.0)")
169+
.await
170+
.unwrap();
171+
sharded
172+
.execute("/* pgdog_shard: 1 */ INSERT INTO avg_helper_cleanup(price) VALUES (18.0), (22.0)")
173+
.await
174+
.unwrap();
175+
176+
sqlx::query("/* test_avg_helper_cleanup */ SELECT AVG(price) FROM avg_helper_cleanup")
177+
.fetch_one(&sharded)
178+
.await
179+
.unwrap();
180+
181+
// Clean up tables so subsequent tests start fresh.
182+
for shard in [0, 1] {
183+
let drop = format!(
184+
"/* pgdog_shard: {} */ DROP TABLE IF EXISTS avg_helper_cleanup",
185+
shard
186+
);
187+
sharded.execute(drop.as_str()).await.ok();
188+
}
189+
190+
sharded.close().await;
191+
primary.close().await;
192+
193+
let prepared = admin
194+
.fetch_all("SHOW PREPARED")
195+
.await
196+
.unwrap()
197+
.into_iter()
198+
.filter(|row| {
199+
row.get::<String, &str>("statement")
200+
.contains("test_avg_helper_cleanup")
201+
})
202+
.collect::<Vec<_>>();
203+
204+
assert!(
205+
prepared.is_empty(),
206+
"helper rewrite statements should be evicted once the connection closes"
207+
);
208+
209+
admin.execute("RELOAD").await.unwrap();
210+
}
211+
82212
#[tokio::test]
83213
async fn test_prepard_cache_eviction() {
84214
let admin = admin_sqlx().await;

pgdog/src/frontend/client_request.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,7 @@ impl ClientRequest {
180180
for message in self.messages.iter_mut() {
181181
if let ProtocolMessage::Parse(parse) = message {
182182
parse.set_query(query);
183-
let name = parse.name().to_owned();
184-
let _ = prepared.update_query(&name, query);
185-
if !plan.is_noop() {
186-
prepared.set_rewrite_plan(&name, plan.clone());
187-
}
183+
prepared.update_and_set_rewrite_plan(&parse.name(), query, plan.clone());
188184
updated = true;
189185
}
190186
}

pgdog/src/frontend/prepared_statements/global_cache.rs

Lines changed: 99 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@ fn global_name(counter: usize) -> String {
1717
pub struct Statement {
1818
parse: Parse,
1919
row_description: Option<RowDescription>,
20+
#[allow(dead_code)]
2021
version: usize,
2122
rewrite_plan: Option<RewritePlan>,
23+
cache_key: CacheKey,
24+
evict_on_close: bool,
2225
}
2326

2427
impl MemoryUsage for Statement {
@@ -30,8 +33,8 @@ impl MemoryUsage for Statement {
3033
} else {
3134
0
3235
}
33-
// Rewrite plans are small; treat as zero-cost for now.
34-
+ 0
36+
+ self.cache_key.memory_usage()
37+
+ self.evict_on_close.memory_usage()
3538
}
3639
}
3740

@@ -41,11 +44,7 @@ impl Statement {
4144
}
4245

4346
fn cache_key(&self) -> CacheKey {
44-
CacheKey {
45-
query: self.parse.query_ref(),
46-
data_types: self.parse.data_types_ref(),
47-
version: self.version,
48-
}
47+
self.cache_key.clone()
4948
}
5049
}
5150

@@ -147,14 +146,14 @@ impl GlobalCache {
147146
let name = global_name(self.counter);
148147
let parse = parse.rename(&name);
149148

150-
let parse_key = CacheKey {
149+
let cache_key = CacheKey {
151150
query: parse.query_ref(),
152151
data_types: parse.data_types_ref(),
153152
version: 0,
154153
};
155154

156155
self.statements.insert(
157-
parse_key,
156+
cache_key.clone(),
158157
CachedStmt {
159158
counter: self.counter,
160159
used: 1,
@@ -168,6 +167,8 @@ impl GlobalCache {
168167
row_description: None,
169168
version: 0,
170169
rewrite_plan: None,
170+
cache_key,
171+
evict_on_close: false,
171172
},
172173
);
173174

@@ -205,6 +206,8 @@ impl GlobalCache {
205206
row_description: None,
206207
version: self.versions,
207208
rewrite_plan: None,
209+
cache_key: key,
210+
evict_on_close: false,
208211
},
209212
);
210213

@@ -221,27 +224,27 @@ impl GlobalCache {
221224
}
222225
}
223226

224-
pub fn update_query(&mut self, name: &str, sql: &str) -> bool {
227+
pub fn update_and_set_rewrite_plan(
228+
&mut self,
229+
name: &str,
230+
sql: &str,
231+
plan: RewritePlan,
232+
) -> bool {
225233
if let Some(statement) = self.names.get_mut(name) {
226-
let old_key = statement.cache_key();
227-
let cached = self.statements.remove(&old_key);
228234
statement.parse.set_query(sql);
229-
let new_key = statement.cache_key();
230-
if let Some(entry) = cached {
231-
self.statements.insert(new_key, entry);
235+
if !plan.is_noop() {
236+
statement.evict_on_close = !plan.helpers().is_empty();
237+
statement.rewrite_plan = Some(plan);
238+
} else {
239+
statement.evict_on_close = false;
240+
statement.rewrite_plan = None;
232241
}
233242
true
234243
} else {
235244
false
236245
}
237246
}
238247

239-
pub fn set_rewrite_plan(&mut self, name: &str, plan: RewritePlan) {
240-
if let Some(statement) = self.names.get_mut(name) {
241-
statement.rewrite_plan = Some(plan);
242-
}
243-
}
244-
245248
pub fn rewrite_plan(&self, name: &str) -> Option<RewritePlan> {
246249
self.names.get(name).and_then(|s| s.rewrite_plan.clone())
247250
}
@@ -290,27 +293,34 @@ impl GlobalCache {
290293

291294
/// Close prepared statement.
292295
pub fn close(&mut self, name: &str, capacity: usize) -> bool {
293-
let used = if let Some(stmt) = self.names.get(name) {
294-
if let Some(stmt) = self.statements.get_mut(&stmt.cache_key()) {
295-
stmt.used = stmt.used.saturating_sub(1);
296-
stmt.used > 0
297-
} else {
298-
false
296+
if let Some(statement) = self.names.get(name) {
297+
let key = statement.cache_key();
298+
let mut used_remaining = None;
299+
300+
if let Some(entry) = self.statements.get_mut(&key) {
301+
entry.used = entry.used.saturating_sub(1);
302+
used_remaining = Some(entry.used);
303+
if entry.used == 0 && (statement.evict_on_close || self.len() > capacity) {
304+
self.remove(name);
305+
return true;
306+
}
299307
}
300-
} else {
301-
false
302-
};
303308

304-
if !used && self.len() > capacity {
305-
self.remove(name);
306-
true
307-
} else {
308-
false
309+
return used_remaining.map(|u| u > 0).unwrap_or(false);
309310
}
311+
312+
false
310313
}
311314

312315
/// Close all unused statements exceeding capacity.
313316
pub fn close_unused(&mut self, capacity: usize) -> usize {
317+
if capacity == 0 {
318+
let removed = self.statements.len();
319+
self.statements.clear();
320+
self.names.clear();
321+
return removed;
322+
}
323+
314324
let mut remove = self.statements.len() as i64 - capacity as i64;
315325
let mut to_remove = vec![];
316326
for stmt in self.statements.values() {
@@ -409,7 +419,16 @@ mod test {
409419
names.push(name);
410420
}
411421

412-
assert_eq!(cache.close_unused(0), 0);
422+
assert_eq!(cache.close_unused(0), 25);
423+
assert!(cache.is_empty());
424+
425+
names.clear();
426+
for stmt in 0..25 {
427+
let parse = Parse::named("__sqlx_1", format!("SELECT {}", stmt));
428+
let (new, name) = cache.insert(&parse);
429+
assert!(new);
430+
names.push(name);
431+
}
413432

414433
for name in &names[0..5] {
415434
assert!(!cache.close(name, 25)); // Won't close because
@@ -422,4 +441,48 @@ mod test {
422441
assert_eq!(cache.close_unused(19), 0);
423442
assert_eq!(cache.len(), 20);
424443
}
444+
445+
#[test]
446+
fn test_close_unused_zero_clears_all_entries() {
447+
let mut cache = GlobalCache::default();
448+
449+
for idx in 0..5 {
450+
let parse = Parse::named("test", format!("SELECT {}", idx));
451+
let (_is_new, _name) = cache.insert(&parse);
452+
}
453+
454+
assert!(cache.len() > 0);
455+
456+
let removed = cache.close_unused(0);
457+
assert_eq!(removed, 5);
458+
assert!(cache.is_empty());
459+
}
460+
461+
#[test]
462+
fn test_update_query_reuses_cache_key() {
463+
let mut cache = GlobalCache::default();
464+
let parse = Parse::named("__sqlx_1", "SELECT 1");
465+
let (is_new, name) = cache.insert(&parse);
466+
assert!(is_new);
467+
468+
assert!(cache.update_and_set_rewrite_plan(
469+
&name,
470+
"SELECT 1 ORDER BY 1",
471+
RewritePlan::default()
472+
));
473+
474+
let key = cache
475+
.statements()
476+
.keys()
477+
.next()
478+
.expect("statement key missing");
479+
assert_eq!(key.query().unwrap(), "SELECT 1");
480+
assert_eq!(cache.query(&name).unwrap(), "SELECT 1 ORDER BY 1");
481+
482+
let parse_again = Parse::named("__sqlx_2", "SELECT 1");
483+
let (is_new_again, reused_name) = cache.insert(&parse_again);
484+
assert!(!is_new_again);
485+
assert_eq!(reused_name, name);
486+
assert_eq!(cache.len(), 1);
487+
}
425488
}

0 commit comments

Comments
 (0)