@@ -7,23 +7,28 @@ use cost_model::CostModel;
77use sqlx:: { postgres:: PgListener , PgPool } ;
88use std:: {
99 cmp:: min,
10- collections:: HashMap ,
10+ collections:: { hash_map :: Entry , HashMap , VecDeque } ,
1111 str:: FromStr ,
1212 sync:: { Arc , Mutex , RwLock } ,
1313 time:: Duration ,
1414} ;
15- use thegraph_core:: DeploymentId ;
15+ use thegraph_core:: { DeploymentId , ParseDeploymentIdError } ;
16+ use tokio:: { sync:: mpsc:: channel, task:: JoinHandle , time:: sleep} ;
1617use tracing:: error;
17- use ttl_cache:: TtlCache ;
1818
1919use tap_core:: receipt:: {
2020 checks:: { Check , CheckError , CheckResult } ,
2121 state:: Checking ,
2222 Context , ReceiptWithState ,
2323} ;
2424
25+ // we only accept receipts with minimal 1 wei grt
26+ const MINIMAL_VALUE : u128 = 1 ;
27+
28+ type CostModelMap = Arc < RwLock < HashMap < DeploymentId , RwLock < CostModelCache > > > > ;
29+
2530pub struct MinimumValue {
26- cost_model_cache : Arc < RwLock < HashMap < DeploymentId , Mutex < CostModelCache > > > > ,
31+ cost_model_cache : CostModelMap ,
2732 watcher_cancel_token : tokio_util:: sync:: CancellationToken ,
2833}
2934
@@ -37,9 +42,7 @@ impl Drop for MinimumValue {
3742
3843impl MinimumValue {
3944 pub async fn new ( pgpool : PgPool ) -> Self {
40- let cost_model_cache = Arc :: new ( RwLock :: new (
41- HashMap :: < DeploymentId , Mutex < CostModelCache > > :: new ( ) ,
42- ) ) ;
45+ let cost_model_cache: CostModelMap = Default :: default ( ) ;
4346
4447 let mut pglistener = PgListener :: connect_with ( & pgpool. clone ( ) ) . await . unwrap ( ) ;
4548 pglistener. listen ( "cost_models_update_notify" ) . await . expect (
@@ -64,28 +67,48 @@ impl MinimumValue {
6467 fn get_expected_value ( & self , agora_query : & AgoraQuery ) -> anyhow:: Result < u128 > {
6568 // get agora model for the allocation_id
6669 let cache = self . cost_model_cache . read ( ) . unwrap ( ) ;
67- // on average, we'll have zero or one model
6870 let models = cache. get ( & agora_query. deployment_id ) ;
6971
7072 let expected_value = models
71- . map ( |cache| cache. lock ( ) . unwrap ( ) . cost ( agora_query) )
72- . unwrap_or_default ( ) ;
73+ . map ( |cache| {
74+ let cache = cache. read ( ) . unwrap ( ) ;
75+ cache. cost ( agora_query)
76+ } )
77+ . unwrap_or ( MINIMAL_VALUE ) ;
7378
7479 Ok ( expected_value)
7580 }
7681
7782 async fn cost_models_watcher (
7883 pgpool : PgPool ,
7984 mut pglistener : PgListener ,
80- cost_model_cache : Arc < RwLock < HashMap < DeploymentId , Mutex < CostModelCache > > > > ,
85+ cost_model_cache : CostModelMap ,
8186 cancel_token : tokio_util:: sync:: CancellationToken ,
8287 ) {
88+ let handles: Arc < Mutex < HashMap < DeploymentId , VecDeque < JoinHandle < ( ) > > > > > =
89+ Default :: default ( ) ;
90+ let ( tx, mut rx) = channel :: < DeploymentId > ( 64 ) ;
91+
8392 loop {
8493 tokio:: select! {
8594 _ = cancel_token. cancelled( ) => {
8695 break ;
8796 }
97+ Some ( deployment_id) = rx. recv( ) => {
98+ let mut cost_model_write = cost_model_cache. write( ) . unwrap( ) ;
99+ if let Some ( cache) = cost_model_write. get_mut( & deployment_id) {
100+ cache. get_mut( ) . unwrap( ) . expire( ) ;
101+ }
88102
103+ if let Entry :: Occupied ( mut entry) = handles. lock( ) . unwrap( ) . entry( deployment_id) {
104+ let vec = entry. get_mut( ) ;
105+ vec. pop_front( ) ;
106+ if vec. is_empty( ) {
107+ entry. remove( ) ;
108+ }
109+ }
110+
111+ }
89112 pg_notification = pglistener. recv( ) => {
90113 let pg_notification = pg_notification. expect(
91114 "should be able to receive Postgres Notify events on the channel \
@@ -103,31 +126,38 @@ impl MinimumValue {
103126 match cost_model_notification. tg_op. as_str( ) {
104127 "INSERT" => {
105128 let cost_model_source: CostModelSource = cost_model_notification. into( ) ;
106- let mut cost_model_cache = cost_model_cache
107- . write( )
108- . unwrap( ) ;
109-
110- match cost_model_cache. get_mut( & deployment_id) {
111- Some ( cache) => {
112- let _ = cache. lock( ) . unwrap( ) . insert_model( cost_model_source) ;
113- } ,
114- None => {
115- if let Ok ( cache) = CostModelCache :: new( cost_model_source) . inspect_err( |err| {
116- tracing:: error!(
117- "Error while compiling cost model for deployment id {}. Error: {}" ,
118- deployment_id, err
119- )
120- } ) {
121- cost_model_cache. insert( deployment_id, Mutex :: new( cache) ) ;
122- }
123- } ,
129+ {
130+ let mut cost_model_write = cost_model_cache
131+ . write( )
132+ . unwrap( ) ;
133+ let cache = cost_model_write. entry( deployment_id) . or_default( ) ;
134+ let _ = cache. get_mut( ) . unwrap( ) . insert_model( cost_model_source) ;
124135 }
136+ let _tx = tx. clone( ) ;
137+
138+ // expire after 60 seconds
139+ handles. lock( )
140+ . unwrap( )
141+ . entry( deployment_id)
142+ . or_default( )
143+ . push_back( tokio:: spawn( async move {
144+ // 1 minute after, we expire the older cache
145+ sleep( Duration :: from_secs( 60 ) ) . await ;
146+ let _ = _tx. send( deployment_id) . await ;
147+ } ) ) ;
125148 }
126149 "DELETE" => {
127- cost_model_cache
128- . write( )
129- . unwrap( )
130- . remove( & cost_model_notification. deployment) ;
150+ if let Entry :: Occupied ( mut entry) = cost_model_cache
151+ . write( ) . unwrap( ) . entry( cost_model_notification. deployment) {
152+ let should_remove = {
153+ let mut cost_model = entry. get_mut( ) . write( ) . unwrap( ) ;
154+ cost_model. expire( ) ;
155+ cost_model. is_empty( )
156+ } ;
157+ if should_remove {
158+ entry. remove( ) ;
159+ }
160+ }
131161 }
132162 // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost
133163 // model cache.
@@ -138,6 +168,17 @@ impl MinimumValue {
138168 cost_model_notification. tg_op
139169 ) ;
140170
171+ {
172+ // clear all pending expire
173+ let mut handles = handles. lock( ) . unwrap( ) ;
174+ for maps in handles. values( ) {
175+ for handle in maps {
176+ handle. abort( ) ;
177+ }
178+ }
179+ handles. clear( ) ;
180+ }
181+
141182 Self :: value_check_reload( & pgpool, cost_model_cache. clone( ) )
142183 . await
143184 . expect( "should be able to reload cost models" )
@@ -150,7 +191,7 @@ impl MinimumValue {
150191
151192 async fn value_check_reload (
152193 pgpool : & PgPool ,
153- cost_model_cache : Arc < RwLock < HashMap < DeploymentId , Mutex < CostModelCache > > > > ,
194+ cost_model_cache : CostModelMap ,
154195 ) -> anyhow:: Result < ( ) > {
155196 let models = sqlx:: query!(
156197 r#"
@@ -166,13 +207,14 @@ impl MinimumValue {
166207 . into_iter ( )
167208 . map ( |record| {
168209 let deployment_id = DeploymentId :: from_str ( & record. deployment . unwrap ( ) ) ?;
169- let model = CostModelCache :: new ( CostModelSource {
210+ let mut model = CostModelCache :: default ( ) ;
211+ let _ = model. insert_model ( CostModelSource {
170212 deployment_id,
171213 model : record. model . unwrap ( ) ,
172- variables : record. variables . unwrap ( ) . to_string ( ) ,
173- } ) ? ;
214+ variables : record. variables . unwrap_or_default ( ) ,
215+ } ) ;
174216
175- Ok :: < _ , anyhow :: Error > ( ( deployment_id, Mutex :: new ( model) ) )
217+ Ok :: < _ , ParseDeploymentIdError > ( ( deployment_id, RwLock :: new ( model) ) )
176218 } )
177219 . collect :: < Result < HashMap < _ , _ > , _ > > ( ) ?;
178220
@@ -220,7 +262,7 @@ fn compile_cost_model(src: CostModelSource) -> anyhow::Result<CostModel> {
220262 if src. model . len ( ) > ( 1 << 16 ) {
221263 return Err ( anyhow ! ( "CostModelTooLarge" ) ) ;
222264 }
223- let model = CostModel :: compile ( & src. model , & src. variables ) ?;
265+ let model = CostModel :: compile ( & src. model , & src. variables . to_string ( ) ) ?;
224266 Ok ( model)
225267}
226268
@@ -231,18 +273,18 @@ pub struct AgoraQuery {
231273}
232274
233275#[ derive( Clone , Eq , Hash , PartialEq ) ]
234- pub struct CostModelSource {
276+ struct CostModelSource {
235277 pub deployment_id : DeploymentId ,
236278 pub model : String ,
237- pub variables : String ,
279+ pub variables : serde_json :: Value ,
238280}
239281
240282#[ derive( serde:: Deserialize ) ]
241283struct CostModelNotification {
242284 tg_op : String ,
243285 deployment : DeploymentId ,
244286 model : String ,
245- variables : String ,
287+ variables : serde_json :: Value ,
246288}
247289
248290impl From < CostModelNotification > for CostModelSource {
@@ -255,48 +297,29 @@ impl From<CostModelNotification> for CostModelSource {
255297 }
256298}
257299
258- pub struct CostModelCache {
259- models : TtlCache < CostModelSource , CostModel > ,
260- latest_model : CostModel ,
261- latest_source : CostModelSource ,
300+ #[ derive( Default ) ]
301+ struct CostModelCache {
302+ models : VecDeque < CostModel > ,
262303}
263304
264305impl CostModelCache {
265- pub fn new ( source : CostModelSource ) -> anyhow:: Result < Self > {
266- let model = compile_cost_model ( source. clone ( ) ) ?;
267- Ok ( Self {
268- latest_model : model,
269- latest_source : source,
270- // arbitrary number of models copy
271- models : TtlCache :: new ( 10 ) ,
272- } )
273- }
274-
275306 fn insert_model ( & mut self , source : CostModelSource ) -> anyhow:: Result < ( ) > {
276- if source != self . latest_source {
277- let model = compile_cost_model ( source. clone ( ) ) ?;
278- // update latest and insert into ttl the old model
279- let old_model = std:: mem:: replace ( & mut self . latest_model , model) ;
280- self . latest_source = source. clone ( ) ;
281-
282- self . models
283- // arbitrary cache duration
284- . insert ( source, old_model, Duration :: from_secs ( 60 ) ) ;
285- }
307+ let model = compile_cost_model ( source. clone ( ) ) ?;
308+ self . models . push_back ( model) ;
286309 Ok ( ( ) )
287310 }
288311
289- fn get_models ( & mut self ) -> Vec < & CostModel > {
290- let mut values: Vec < & CostModel > = self . models . iter ( ) . map ( |( _, v) | v) . collect ( ) ;
291- values. push ( & self . latest_model ) ;
292- values
312+ fn expire ( & mut self ) {
313+ self . models . pop_front ( ) ;
293314 }
294315
295- fn cost ( & mut self , query : & AgoraQuery ) -> u128 {
296- let models = self . get_models ( ) ;
316+ fn is_empty ( & self ) -> bool {
317+ self . models . is_empty ( )
318+ }
297319
298- models
299- . into_iter ( )
320+ fn cost ( & self , query : & AgoraQuery ) -> u128 {
321+ self . models
322+ . iter ( )
300323 . fold ( None , |acc, model| {
301324 let value = model
302325 . cost ( & query. query , & query. variables )
@@ -310,7 +333,7 @@ impl CostModelCache {
310333 Some ( value)
311334 }
312335 } )
313- . unwrap_or_default ( )
336+ . unwrap_or ( MINIMAL_VALUE )
314337 }
315338}
316339
@@ -319,17 +342,17 @@ mod tests {
319342 use sqlx:: PgPool ;
320343
321344 #[ sqlx:: test( migrations = "../migrations" ) ]
322- async fn initialize_check ( pg_pool : PgPool ) { }
345+ async fn initialize_check ( _pg_pool : PgPool ) { }
323346
324347 #[ sqlx:: test( migrations = "../migrations" ) ]
325- async fn should_initialize_check_with_caches ( pg_pool : PgPool ) { }
348+ async fn should_initialize_check_with_caches ( _pg_pool : PgPool ) { }
326349
327350 #[ sqlx:: test( migrations = "../migrations" ) ]
328- async fn should_add_model_to_cache_on_insert ( pg_pool : PgPool ) { }
351+ async fn should_add_model_to_cache_on_insert ( _pg_pool : PgPool ) { }
329352
330353 #[ sqlx:: test( migrations = "../migrations" ) ]
331- async fn should_expire_old_model ( pg_pool : PgPool ) { }
354+ async fn should_expire_old_model ( _pg_pool : PgPool ) { }
332355
333356 #[ sqlx:: test( migrations = "../migrations" ) ]
334- async fn should_verify_global_model ( pg_pool : PgPool ) { }
357+ async fn should_verify_global_model ( _pg_pool : PgPool ) { }
335358}
0 commit comments