@@ -8,7 +8,8 @@ use sqlx::{postgres::PgListener, PgPool};
88use std:: {
99 cmp:: min,
1010 collections:: HashMap ,
11- sync:: { Arc , Mutex } ,
11+ str:: FromStr ,
12+ sync:: { Arc , Mutex , RwLock } ,
1213 time:: Duration ,
1314} ;
1415use thegraph_core:: DeploymentId ;
@@ -22,7 +23,7 @@ use tap_core::receipt::{
2223} ;
2324
2425pub struct MinimumValue {
25- cost_model_cache : Arc < Mutex < HashMap < DeploymentId , CostModelCache > > > ,
26+ cost_model_cache : Arc < RwLock < HashMap < DeploymentId , Mutex < CostModelCache > > > > ,
2627 watcher_cancel_token : tokio_util:: sync:: CancellationToken ,
2728}
2829
@@ -36,7 +37,9 @@ impl Drop for MinimumValue {
3637
3738impl MinimumValue {
3839 pub async fn new ( pgpool : PgPool ) -> Self {
39- let cost_model_cache = Arc :: new ( Mutex :: new ( HashMap :: < DeploymentId , CostModelCache > :: new ( ) ) ) ;
40+ let cost_model_cache = Arc :: new ( RwLock :: new (
41+ HashMap :: < DeploymentId , Mutex < CostModelCache > > :: new ( ) ,
42+ ) ) ;
4043
4144 let mut pglistener = PgListener :: connect_with ( & pgpool. clone ( ) ) . await . unwrap ( ) ;
4245 pglistener. listen ( "cost_models_update_notify" ) . await . expect (
@@ -58,10 +61,23 @@ impl MinimumValue {
5861 }
5962 }
6063
64+ fn get_expected_value ( & self , agora_query : & AgoraQuery ) -> anyhow:: Result < u128 > {
65+ // get agora model for the allocation_id
66+ let cache = self . cost_model_cache . read ( ) . unwrap ( ) ;
67+ // on average, we'll have zero or one model
68+ let models = cache. get ( & agora_query. deployment_id ) ;
69+
70+ let expected_value = models
71+ . map ( |cache| cache. lock ( ) . unwrap ( ) . cost ( agora_query) )
72+ . unwrap_or_default ( ) ;
73+
74+ Ok ( expected_value)
75+ }
76+
6177 async fn cost_models_watcher (
62- _pgpool : PgPool ,
78+ pgpool : PgPool ,
6379 mut pglistener : PgListener ,
64- cost_model_cache : Arc < Mutex < HashMap < DeploymentId , CostModelCache > > > ,
80+ cost_model_cache : Arc < RwLock < HashMap < DeploymentId , Mutex < CostModelCache > > > > ,
6581 cancel_token : tokio_util:: sync:: CancellationToken ,
6682 ) {
6783 loop {
@@ -88,12 +104,12 @@ impl MinimumValue {
88104 "INSERT" => {
89105 let cost_model_source: CostModelSource = cost_model_notification. into( ) ;
90106 let mut cost_model_cache = cost_model_cache
91- . lock ( )
107+ . write ( )
92108 . unwrap( ) ;
93109
94110 match cost_model_cache. get_mut( & deployment_id) {
95111 Some ( cache) => {
96- let _ = cache. insert_model( cost_model_source) ;
112+ let _ = cache. lock ( ) . unwrap ( ) . insert_model( cost_model_source) ;
97113 } ,
98114 None => {
99115 if let Ok ( cache) = CostModelCache :: new( cost_model_source) . inspect_err( |err| {
@@ -102,14 +118,14 @@ impl MinimumValue {
102118 deployment_id, err
103119 )
104120 } ) {
105- cost_model_cache. insert( deployment_id, cache) ;
121+ cost_model_cache. insert( deployment_id, Mutex :: new ( cache) ) ;
106122 }
107123 } ,
108124 }
109125 }
110126 "DELETE" => {
111127 cost_model_cache
112- . lock ( )
128+ . write ( )
113129 . unwrap( )
114130 . remove( & cost_model_notification. deployment) ;
115131 }
@@ -122,29 +138,47 @@ impl MinimumValue {
122138 cost_model_notification. tg_op
123139 ) ;
124140
125- // Self::sender_denylist_reload( pgpool.clone(), denylist .clone())
126- // .await
127- // .expect("should be able to reload cost models")
141+ Self :: value_check_reload ( & pgpool, cost_model_cache . clone( ) )
142+ . await
143+ . expect( "should be able to reload cost models" )
128144 }
129145 }
130146 }
131147 }
132148 }
133149 }
134- }
135150
136- impl MinimumValue {
137- fn get_expected_value ( & self , agora_query : & AgoraQuery ) -> anyhow:: Result < u128 > {
138- // get agora model for the allocation_id
139- let mut cache = self . cost_model_cache . lock ( ) . unwrap ( ) ;
140- // on average, we'll have zero or one model
141- let models = cache. get_mut ( & agora_query. deployment_id ) ;
151+ async fn value_check_reload (
152+ pgpool : & PgPool ,
153+ cost_model_cache : Arc < RwLock < HashMap < DeploymentId , Mutex < CostModelCache > > > > ,
154+ ) -> anyhow:: Result < ( ) > {
155+ let models = sqlx:: query!(
156+ r#"
157+ SELECT deployment, model, variables
158+ FROM "CostModels"
159+ WHERE deployment != 'global'
160+ ORDER BY deployment ASC
161+ "#
162+ )
163+ . fetch_all ( pgpool)
164+ . await ?;
165+ let models = models
166+ . into_iter ( )
167+ . map ( |record| {
168+ let deployment_id = DeploymentId :: from_str ( & record. deployment . unwrap ( ) ) ?;
169+ let model = CostModelCache :: new ( CostModelSource {
170+ deployment_id,
171+ model : record. model . unwrap ( ) ,
172+ variables : record. variables . unwrap ( ) . to_string ( ) ,
173+ } ) ?;
174+
175+ Ok :: < _ , anyhow:: Error > ( ( deployment_id, Mutex :: new ( model) ) )
176+ } )
177+ . collect :: < Result < HashMap < _ , _ > , _ > > ( ) ?;
142178
143- let expected_value = models
144- . map ( |cache| cache. cost ( agora_query) )
145- . unwrap_or_default ( ) ;
179+ * ( cost_model_cache. write ( ) . unwrap ( ) ) = models;
146180
147- Ok ( expected_value )
181+ Ok ( ( ) )
148182 }
149183}
150184
@@ -279,3 +313,23 @@ impl CostModelCache {
279313 . unwrap_or_default ( )
280314 }
281315}
316+
317+ #[ cfg( test) ]
318+ mod tests {
319+ use sqlx:: PgPool ;
320+
321+ #[ sqlx:: test( migrations = "../migrations" ) ]
322+ async fn initialize_check ( pg_pool : PgPool ) { }
323+
324+ #[ sqlx:: test( migrations = "../migrations" ) ]
325+ async fn should_initialize_check_with_caches ( pg_pool : PgPool ) { }
326+
327+ #[ sqlx:: test( migrations = "../migrations" ) ]
328+ async fn should_add_model_to_cache_on_insert ( pg_pool : PgPool ) { }
329+
330+ #[ sqlx:: test( migrations = "../migrations" ) ]
331+ async fn should_expire_old_model ( pg_pool : PgPool ) { }
332+
333+ #[ sqlx:: test( migrations = "../migrations" ) ]
334+ async fn should_verify_global_model ( pg_pool : PgPool ) { }
335+ }
0 commit comments