@@ -17,8 +17,11 @@ fn global_name(counter: usize) -> String {
17
17
pub struct Statement {
18
18
parse : Parse ,
19
19
row_description : Option < RowDescription > ,
20
+ #[ allow( dead_code) ]
20
21
version : usize ,
21
22
rewrite_plan : Option < RewritePlan > ,
23
+ cache_key : CacheKey ,
24
+ evict_on_close : bool ,
22
25
}
23
26
24
27
impl MemoryUsage for Statement {
@@ -30,8 +33,8 @@ impl MemoryUsage for Statement {
30
33
} else {
31
34
0
32
35
}
33
- // Rewrite plans are small; treat as zero-cost for now.
34
- + 0
36
+ + self . cache_key . memory_usage ( )
37
+ + self . evict_on_close . memory_usage ( )
35
38
}
36
39
}
37
40
@@ -41,11 +44,7 @@ impl Statement {
41
44
}
42
45
43
46
fn cache_key ( & self ) -> CacheKey {
44
- CacheKey {
45
- query : self . parse . query_ref ( ) ,
46
- data_types : self . parse . data_types_ref ( ) ,
47
- version : self . version ,
48
- }
47
+ self . cache_key . clone ( )
49
48
}
50
49
}
51
50
@@ -147,14 +146,14 @@ impl GlobalCache {
147
146
let name = global_name ( self . counter ) ;
148
147
let parse = parse. rename ( & name) ;
149
148
150
- let parse_key = CacheKey {
149
+ let cache_key = CacheKey {
151
150
query : parse. query_ref ( ) ,
152
151
data_types : parse. data_types_ref ( ) ,
153
152
version : 0 ,
154
153
} ;
155
154
156
155
self . statements . insert (
157
- parse_key ,
156
+ cache_key . clone ( ) ,
158
157
CachedStmt {
159
158
counter : self . counter ,
160
159
used : 1 ,
@@ -168,6 +167,8 @@ impl GlobalCache {
168
167
row_description : None ,
169
168
version : 0 ,
170
169
rewrite_plan : None ,
170
+ cache_key,
171
+ evict_on_close : false ,
171
172
} ,
172
173
) ;
173
174
@@ -205,6 +206,8 @@ impl GlobalCache {
205
206
row_description : None ,
206
207
version : self . versions ,
207
208
rewrite_plan : None ,
209
+ cache_key : key,
210
+ evict_on_close : false ,
208
211
} ,
209
212
) ;
210
213
@@ -221,27 +224,27 @@ impl GlobalCache {
221
224
}
222
225
}
223
226
224
- pub fn update_query ( & mut self , name : & str , sql : & str ) -> bool {
227
+ pub fn update_and_set_rewrite_plan (
228
+ & mut self ,
229
+ name : & str ,
230
+ sql : & str ,
231
+ plan : RewritePlan ,
232
+ ) -> bool {
225
233
if let Some ( statement) = self . names . get_mut ( name) {
226
- let old_key = statement. cache_key ( ) ;
227
- let cached = self . statements . remove ( & old_key) ;
228
234
statement. parse . set_query ( sql) ;
229
- let new_key = statement. cache_key ( ) ;
230
- if let Some ( entry) = cached {
231
- self . statements . insert ( new_key, entry) ;
235
+ if !plan. is_noop ( ) {
236
+ statement. evict_on_close = !plan. helpers ( ) . is_empty ( ) ;
237
+ statement. rewrite_plan = Some ( plan) ;
238
+ } else {
239
+ statement. evict_on_close = false ;
240
+ statement. rewrite_plan = None ;
232
241
}
233
242
true
234
243
} else {
235
244
false
236
245
}
237
246
}
238
247
239
- pub fn set_rewrite_plan ( & mut self , name : & str , plan : RewritePlan ) {
240
- if let Some ( statement) = self . names . get_mut ( name) {
241
- statement. rewrite_plan = Some ( plan) ;
242
- }
243
- }
244
-
245
248
pub fn rewrite_plan ( & self , name : & str ) -> Option < RewritePlan > {
246
249
self . names . get ( name) . and_then ( |s| s. rewrite_plan . clone ( ) )
247
250
}
@@ -290,27 +293,34 @@ impl GlobalCache {
290
293
291
294
/// Close prepared statement.
292
295
pub fn close ( & mut self , name : & str , capacity : usize ) -> bool {
293
- let used = if let Some ( stmt) = self . names . get ( name) {
294
- if let Some ( stmt) = self . statements . get_mut ( & stmt. cache_key ( ) ) {
295
- stmt. used = stmt. used . saturating_sub ( 1 ) ;
296
- stmt. used > 0
297
- } else {
298
- false
296
+ if let Some ( statement) = self . names . get ( name) {
297
+ let key = statement. cache_key ( ) ;
298
+ let mut used_remaining = None ;
299
+
300
+ if let Some ( entry) = self . statements . get_mut ( & key) {
301
+ entry. used = entry. used . saturating_sub ( 1 ) ;
302
+ used_remaining = Some ( entry. used ) ;
303
+ if entry. used == 0 && ( statement. evict_on_close || self . len ( ) > capacity) {
304
+ self . remove ( name) ;
305
+ return true ;
306
+ }
299
307
}
300
- } else {
301
- false
302
- } ;
303
308
304
- if !used && self . len ( ) > capacity {
305
- self . remove ( name) ;
306
- true
307
- } else {
308
- false
309
+ return used_remaining. map ( |u| u > 0 ) . unwrap_or ( false ) ;
309
310
}
311
+
312
+ false
310
313
}
311
314
312
315
/// Close all unused statements exceeding capacity.
313
316
pub fn close_unused ( & mut self , capacity : usize ) -> usize {
317
+ if capacity == 0 {
318
+ let removed = self . statements . len ( ) ;
319
+ self . statements . clear ( ) ;
320
+ self . names . clear ( ) ;
321
+ return removed;
322
+ }
323
+
314
324
let mut remove = self . statements . len ( ) as i64 - capacity as i64 ;
315
325
let mut to_remove = vec ! [ ] ;
316
326
for stmt in self . statements . values ( ) {
@@ -409,7 +419,16 @@ mod test {
409
419
names. push ( name) ;
410
420
}
411
421
412
- assert_eq ! ( cache. close_unused( 0 ) , 0 ) ;
422
+ assert_eq ! ( cache. close_unused( 0 ) , 25 ) ;
423
+ assert ! ( cache. is_empty( ) ) ;
424
+
425
+ names. clear ( ) ;
426
+ for stmt in 0 ..25 {
427
+ let parse = Parse :: named ( "__sqlx_1" , format ! ( "SELECT {}" , stmt) ) ;
428
+ let ( new, name) = cache. insert ( & parse) ;
429
+ assert ! ( new) ;
430
+ names. push ( name) ;
431
+ }
413
432
414
433
for name in & names[ 0 ..5 ] {
415
434
assert ! ( !cache. close( name, 25 ) ) ; // Won't close because
@@ -422,4 +441,48 @@ mod test {
422
441
assert_eq ! ( cache. close_unused( 19 ) , 0 ) ;
423
442
assert_eq ! ( cache. len( ) , 20 ) ;
424
443
}
444
+
445
+ #[ test]
446
+ fn test_close_unused_zero_clears_all_entries ( ) {
447
+ let mut cache = GlobalCache :: default ( ) ;
448
+
449
+ for idx in 0 ..5 {
450
+ let parse = Parse :: named ( "test" , format ! ( "SELECT {}" , idx) ) ;
451
+ let ( _is_new, _name) = cache. insert ( & parse) ;
452
+ }
453
+
454
+ assert ! ( cache. len( ) > 0 ) ;
455
+
456
+ let removed = cache. close_unused ( 0 ) ;
457
+ assert_eq ! ( removed, 5 ) ;
458
+ assert ! ( cache. is_empty( ) ) ;
459
+ }
460
+
461
+ #[ test]
462
+ fn test_update_query_reuses_cache_key ( ) {
463
+ let mut cache = GlobalCache :: default ( ) ;
464
+ let parse = Parse :: named ( "__sqlx_1" , "SELECT 1" ) ;
465
+ let ( is_new, name) = cache. insert ( & parse) ;
466
+ assert ! ( is_new) ;
467
+
468
+ assert ! ( cache. update_and_set_rewrite_plan(
469
+ & name,
470
+ "SELECT 1 ORDER BY 1" ,
471
+ RewritePlan :: default ( )
472
+ ) ) ;
473
+
474
+ let key = cache
475
+ . statements ( )
476
+ . keys ( )
477
+ . next ( )
478
+ . expect ( "statement key missing" ) ;
479
+ assert_eq ! ( key. query( ) . unwrap( ) , "SELECT 1" ) ;
480
+ assert_eq ! ( cache. query( & name) . unwrap( ) , "SELECT 1 ORDER BY 1" ) ;
481
+
482
+ let parse_again = Parse :: named ( "__sqlx_2" , "SELECT 1" ) ;
483
+ let ( is_new_again, reused_name) = cache. insert ( & parse_again) ;
484
+ assert ! ( !is_new_again) ;
485
+ assert_eq ! ( reused_name, name) ;
486
+ assert_eq ! ( cache. len( ) , 1 ) ;
487
+ }
425
488
}
0 commit comments