@@ -8,10 +8,12 @@ use sqlx::{
88 postgres:: { PgListener , PgNotification } ,
99 PgPool ,
1010} ;
11+ use std:: time:: Duration ;
1112use std:: {
1213 collections:: HashMap ,
1314 str:: FromStr ,
1415 sync:: { Arc , RwLock } ,
16+ time:: Instant ,
1517} ;
1618use thegraph_core:: DeploymentId ;
1719use tracing:: error;
@@ -39,6 +41,7 @@ pub struct AgoraQuery {
3941
4042type CostModelMap = Arc < RwLock < HashMap < DeploymentId , CostModel > > > ;
4143type GlobalModel = Arc < RwLock < Option < CostModel > > > ;
44+ type GracePeriod = Arc < RwLock < Instant > > ;
4245
4346/// Represents the check for minimum for a receipt
4447///
@@ -48,13 +51,16 @@ pub struct MinimumValue {
4851 cost_model_map : CostModelMap ,
4952 global_model : GlobalModel ,
5053 watcher_cancel_token : tokio_util:: sync:: CancellationToken ,
54+ updated_at : GracePeriod ,
55+ grace_period : Duration ,
5156}
5257
5358struct CostModelWatcher {
5459 pgpool : PgPool ,
5560
5661 cost_models : CostModelMap ,
5762 global_model : GlobalModel ,
63+ updated_at : GracePeriod ,
5864}
5965
6066impl CostModelWatcher {
@@ -64,11 +70,13 @@ impl CostModelWatcher {
6470 cost_models : CostModelMap ,
6571 global_model : GlobalModel ,
6672 cancel_token : tokio_util:: sync:: CancellationToken ,
73+ grace_period : GracePeriod ,
6774 ) {
6875 let cost_model_watcher = CostModelWatcher {
6976 pgpool,
7077 global_model,
7178 cost_models,
79+ updated_at : grace_period,
7280 } ;
7381
7482 loop {
@@ -123,6 +131,8 @@ impl CostModelWatcher {
123131 }
124132 } ,
125133 } ;
134+
135+ * self . updated_at . write ( ) . unwrap ( ) = Instant :: now ( ) ;
126136 }
127137
128138 fn handle_delete ( & self , deployment : String ) {
@@ -142,6 +152,7 @@ impl CostModelWatcher {
142152 }
143153 } ,
144154 } ;
155+ * self . updated_at . write ( ) . unwrap ( ) = Instant :: now ( ) ;
145156 }
146157
147158 async fn handle_unexpected_notification ( & self , payload : & str ) {
@@ -157,7 +168,9 @@ impl CostModelWatcher {
157168 self . global_model . clone ( ) ,
158169 )
159170 . await
160- . expect ( "should be able to reload cost models" )
171+ . expect ( "should be able to reload cost models" ) ;
172+
173+ * self . updated_at . write ( ) . unwrap ( ) = Instant :: now ( ) ;
161174 }
162175}
163176
@@ -170,9 +183,10 @@ impl Drop for MinimumValue {
170183}
171184
172185impl MinimumValue {
173- pub async fn new ( pgpool : PgPool ) -> Self {
186+ pub async fn new ( pgpool : PgPool , grace_period : Duration ) -> Self {
174187 let cost_model_map: CostModelMap = Default :: default ( ) ;
175188 let global_model: GlobalModel = Default :: default ( ) ;
189+ let updated_at: GracePeriod = Arc :: new ( RwLock :: new ( Instant :: now ( ) ) ) ;
176190 Self :: value_check_reload ( & pgpool, cost_model_map. clone ( ) , global_model. clone ( ) )
177191 . await
178192 . expect ( "should be able to reload cost models" ) ;
@@ -193,15 +207,22 @@ impl MinimumValue {
193207 cost_model_map. clone ( ) ,
194208 global_model. clone ( ) ,
195209 watcher_cancel_token. clone ( ) ,
210+ updated_at. clone ( ) ,
196211 ) ) ;
197-
198212 Self {
199213 global_model,
200214 cost_model_map,
201215 watcher_cancel_token,
216+ updated_at,
217+ grace_period,
202218 }
203219 }
204220
221+ fn inside_grace_period ( & self ) -> bool {
222+ let time_elapsed = Instant :: now ( ) . duration_since ( * self . updated_at . read ( ) . unwrap ( ) ) ;
223+ time_elapsed < self . grace_period
224+ }
225+
205226 fn expected_value ( & self , agora_query : & AgoraQuery ) -> anyhow:: Result < u128 > {
206227 // get agora model for the deployment_id
207228 let model = self . cost_model_map . read ( ) . unwrap ( ) ;
@@ -271,14 +292,17 @@ impl Check for MinimumValue {
271292 let agora_query = ctx
272293 . get ( )
273294 . ok_or ( CheckError :: Failed ( anyhow ! ( "Could not find agora query" ) ) ) ?;
295+ // get value
296+ let value = receipt. signed_receipt ( ) . message . value ;
297+
298+ if self . inside_grace_period ( ) && value >= MINIMAL_VALUE {
299+ return Ok ( ( ) ) ;
300+ }
274301
275302 let expected_value = self
276303 . expected_value ( agora_query)
277304 . map_err ( CheckError :: Failed ) ?;
278305
279- // get value
280- let value = receipt. signed_receipt ( ) . message . value ;
281-
282306 let should_accept = value >= expected_value;
283307
284308 tracing:: trace!(
@@ -339,7 +363,7 @@ mod tests {
339363
340364 #[ sqlx:: test( migrations = "../migrations" ) ]
341365 async fn initialize_check ( pgpool : PgPool ) {
342- let check = MinimumValue :: new ( pgpool) . await ;
366+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 0 ) ) . await ;
343367 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
344368 }
345369
@@ -350,7 +374,7 @@ mod tests {
350374
351375 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
352376
353- let check = MinimumValue :: new ( pgpool) . await ;
377+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 0 ) ) . await ;
354378 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 2 ) ;
355379
356380 // no global model
@@ -359,7 +383,7 @@ mod tests {
359383
360384 #[ sqlx:: test( migrations = "../migrations" ) ]
361385 async fn should_watch_model_insert ( pgpool : PgPool ) {
362- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
386+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
363387 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
364388
365389 // insert 2 cost models for different deployment_id
@@ -379,7 +403,7 @@ mod tests {
379403 let test_models = crate :: cost_model:: test:: test_data ( ) ;
380404 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
381405
382- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
406+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
383407 assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 2 ) ;
384408
385409 // remove
@@ -398,13 +422,13 @@ mod tests {
398422 let global_model = global_cost_model ( ) ;
399423 add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
400424
401- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
425+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
402426 assert ! ( check. global_model. read( ) . unwrap( ) . is_some( ) ) ;
403427 }
404428
405429 #[ sqlx:: test( migrations = "../migrations" ) ]
406430 async fn should_watch_global_model ( pgpool : PgPool ) {
407- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
431+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
408432
409433 let global_model = global_cost_model ( ) ;
410434 add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
@@ -418,7 +442,7 @@ mod tests {
418442 let global_model = global_cost_model ( ) ;
419443 add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
420444
421- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
445+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
422446 assert ! ( check. global_model. read( ) . unwrap( ) . is_some( ) ) ;
423447
424448 sqlx:: query!( r#"DELETE FROM "CostModels""# )
@@ -440,7 +464,7 @@ mod tests {
440464
441465 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
442466
443- let check = MinimumValue :: new ( pgpool) . await ;
467+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 1 ) ) . await ;
444468
445469 let deployment_id = test_models[ 0 ] . deployment ;
446470 let mut ctx = Context :: new ( ) ;
@@ -477,6 +501,14 @@ mod tests {
477501 let signed_receipt =
478502 create_signed_receipt ( ALLOCATION_ID , u64:: MAX , u64:: MAX , minimal_value - 1 ) . await ;
479503 let receipt = ReceiptWithState :: new ( signed_receipt) ;
504+
505+ assert ! (
506+ check. check( & ctx, & receipt) . await . is_ok( ) ,
507+ "Should accept since its inside grace period "
508+ ) ;
509+
510+ sleep ( Duration :: from_millis ( 1010 ) ) . await ;
511+
480512 assert ! (
481513 check. check( & ctx, & receipt) . await . is_err( ) ,
482514 "Should require minimal value"
@@ -508,7 +540,7 @@ mod tests {
508540 add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
509541 add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
510542
511- let check = MinimumValue :: new ( pgpool) . await ;
543+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 0 ) ) . await ;
512544
513545 let deployment_id = test_models[ 0 ] . deployment ;
514546 let mut ctx = Context :: new ( ) ;
0 commit comments