@@ -8,10 +8,12 @@ use sqlx::{
8
8
postgres:: { PgListener , PgNotification } ,
9
9
PgPool ,
10
10
} ;
11
+ use std:: time:: Duration ;
11
12
use std:: {
12
13
collections:: HashMap ,
13
14
str:: FromStr ,
14
15
sync:: { Arc , RwLock } ,
16
+ time:: Instant ,
15
17
} ;
16
18
use thegraph_core:: DeploymentId ;
17
19
use tracing:: error;
@@ -39,6 +41,7 @@ pub struct AgoraQuery {
39
41
40
42
type CostModelMap = Arc < RwLock < HashMap < DeploymentId , CostModel > > > ;
41
43
type GlobalModel = Arc < RwLock < Option < CostModel > > > ;
44
+ type GracePeriod = Arc < RwLock < Instant > > ;
42
45
43
46
/// Represents the check for minimum for a receipt
44
47
///
@@ -48,13 +51,16 @@ pub struct MinimumValue {
48
51
cost_model_map : CostModelMap ,
49
52
global_model : GlobalModel ,
50
53
watcher_cancel_token : tokio_util:: sync:: CancellationToken ,
54
+ updated_at : GracePeriod ,
55
+ grace_period : Duration ,
51
56
}
52
57
53
58
struct CostModelWatcher {
54
59
pgpool : PgPool ,
55
60
56
61
cost_models : CostModelMap ,
57
62
global_model : GlobalModel ,
63
+ updated_at : GracePeriod ,
58
64
}
59
65
60
66
impl CostModelWatcher {
@@ -64,11 +70,13 @@ impl CostModelWatcher {
64
70
cost_models : CostModelMap ,
65
71
global_model : GlobalModel ,
66
72
cancel_token : tokio_util:: sync:: CancellationToken ,
73
+ grace_period : GracePeriod ,
67
74
) {
68
75
let cost_model_watcher = CostModelWatcher {
69
76
pgpool,
70
77
global_model,
71
78
cost_models,
79
+ updated_at : grace_period,
72
80
} ;
73
81
74
82
loop {
@@ -123,6 +131,8 @@ impl CostModelWatcher {
123
131
}
124
132
} ,
125
133
} ;
134
+
135
+ * self . updated_at . write ( ) . unwrap ( ) = Instant :: now ( ) ;
126
136
}
127
137
128
138
fn handle_delete ( & self , deployment : String ) {
@@ -142,6 +152,7 @@ impl CostModelWatcher {
142
152
}
143
153
} ,
144
154
} ;
155
+ * self . updated_at . write ( ) . unwrap ( ) = Instant :: now ( ) ;
145
156
}
146
157
147
158
async fn handle_unexpected_notification ( & self , payload : & str ) {
@@ -157,7 +168,9 @@ impl CostModelWatcher {
157
168
self . global_model . clone ( ) ,
158
169
)
159
170
. await
160
- . expect ( "should be able to reload cost models" )
171
+ . expect ( "should be able to reload cost models" ) ;
172
+
173
+ * self . updated_at . write ( ) . unwrap ( ) = Instant :: now ( ) ;
161
174
}
162
175
}
163
176
@@ -170,9 +183,10 @@ impl Drop for MinimumValue {
170
183
}
171
184
172
185
impl MinimumValue {
173
- pub async fn new ( pgpool : PgPool ) -> Self {
186
+ pub async fn new ( pgpool : PgPool , grace_period : Duration ) -> Self {
174
187
let cost_model_map: CostModelMap = Default :: default ( ) ;
175
188
let global_model: GlobalModel = Default :: default ( ) ;
189
+ let updated_at: GracePeriod = Arc :: new ( RwLock :: new ( Instant :: now ( ) ) ) ;
176
190
Self :: value_check_reload ( & pgpool, cost_model_map. clone ( ) , global_model. clone ( ) )
177
191
. await
178
192
. expect ( "should be able to reload cost models" ) ;
@@ -193,15 +207,22 @@ impl MinimumValue {
193
207
cost_model_map. clone ( ) ,
194
208
global_model. clone ( ) ,
195
209
watcher_cancel_token. clone ( ) ,
210
+ updated_at. clone ( ) ,
196
211
) ) ;
197
-
198
212
Self {
199
213
global_model,
200
214
cost_model_map,
201
215
watcher_cancel_token,
216
+ updated_at,
217
+ grace_period,
202
218
}
203
219
}
204
220
221
+ fn inside_grace_period ( & self ) -> bool {
222
+ let time_elapsed = Instant :: now ( ) . duration_since ( * self . updated_at . read ( ) . unwrap ( ) ) ;
223
+ time_elapsed < self . grace_period
224
+ }
225
+
205
226
fn expected_value ( & self , agora_query : & AgoraQuery ) -> anyhow:: Result < u128 > {
206
227
// get agora model for the deployment_id
207
228
let model = self . cost_model_map . read ( ) . unwrap ( ) ;
@@ -271,14 +292,17 @@ impl Check for MinimumValue {
271
292
let agora_query = ctx
272
293
. get ( )
273
294
. ok_or ( CheckError :: Failed ( anyhow ! ( "Could not find agora query" ) ) ) ?;
295
+ // get value
296
+ let value = receipt. signed_receipt ( ) . message . value ;
297
+
298
+ if self . inside_grace_period ( ) && value >= MINIMAL_VALUE {
299
+ return Ok ( ( ) ) ;
300
+ }
274
301
275
302
let expected_value = self
276
303
. expected_value ( agora_query)
277
304
. map_err ( CheckError :: Failed ) ?;
278
305
279
- // get value
280
- let value = receipt. signed_receipt ( ) . message . value ;
281
-
282
306
let should_accept = value >= expected_value;
283
307
284
308
tracing:: trace!(
@@ -339,7 +363,7 @@ mod tests {
339
363
340
364
#[ sqlx:: test( migrations = "../migrations" ) ]
341
365
async fn initialize_check ( pgpool : PgPool ) {
342
- let check = MinimumValue :: new ( pgpool) . await ;
366
+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 0 ) ) . await ;
343
367
assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
344
368
}
345
369
@@ -350,7 +374,7 @@ mod tests {
350
374
351
375
add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
352
376
353
- let check = MinimumValue :: new ( pgpool) . await ;
377
+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 0 ) ) . await ;
354
378
assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 2 ) ;
355
379
356
380
// no global model
@@ -359,7 +383,7 @@ mod tests {
359
383
360
384
#[ sqlx:: test( migrations = "../migrations" ) ]
361
385
async fn should_watch_model_insert ( pgpool : PgPool ) {
362
- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
386
+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
363
387
assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 0 ) ;
364
388
365
389
// insert 2 cost models for different deployment_id
@@ -379,7 +403,7 @@ mod tests {
379
403
let test_models = crate :: cost_model:: test:: test_data ( ) ;
380
404
add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
381
405
382
- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
406
+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
383
407
assert_eq ! ( check. cost_model_map. read( ) . unwrap( ) . len( ) , 2 ) ;
384
408
385
409
// remove
@@ -398,13 +422,13 @@ mod tests {
398
422
let global_model = global_cost_model ( ) ;
399
423
add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
400
424
401
- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
425
+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
402
426
assert ! ( check. global_model. read( ) . unwrap( ) . is_some( ) ) ;
403
427
}
404
428
405
429
#[ sqlx:: test( migrations = "../migrations" ) ]
406
430
async fn should_watch_global_model ( pgpool : PgPool ) {
407
- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
431
+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
408
432
409
433
let global_model = global_cost_model ( ) ;
410
434
add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
@@ -418,7 +442,7 @@ mod tests {
418
442
let global_model = global_cost_model ( ) ;
419
443
add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
420
444
421
- let check = MinimumValue :: new ( pgpool. clone ( ) ) . await ;
445
+ let check = MinimumValue :: new ( pgpool. clone ( ) , Duration :: from_secs ( 0 ) ) . await ;
422
446
assert ! ( check. global_model. read( ) . unwrap( ) . is_some( ) ) ;
423
447
424
448
sqlx:: query!( r#"DELETE FROM "CostModels""# )
@@ -440,7 +464,7 @@ mod tests {
440
464
441
465
add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
442
466
443
- let check = MinimumValue :: new ( pgpool) . await ;
467
+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 1 ) ) . await ;
444
468
445
469
let deployment_id = test_models[ 0 ] . deployment ;
446
470
let mut ctx = Context :: new ( ) ;
@@ -477,6 +501,14 @@ mod tests {
477
501
let signed_receipt =
478
502
create_signed_receipt ( ALLOCATION_ID , u64:: MAX , u64:: MAX , minimal_value - 1 ) . await ;
479
503
let receipt = ReceiptWithState :: new ( signed_receipt) ;
504
+
505
+ assert ! (
506
+ check. check( & ctx, & receipt) . await . is_ok( ) ,
507
+ "Should accept since its inside grace period "
508
+ ) ;
509
+
510
+ sleep ( Duration :: from_millis ( 1010 ) ) . await ;
511
+
480
512
assert ! (
481
513
check. check( & ctx, & receipt) . await . is_err( ) ,
482
514
"Should require minimal value"
@@ -508,7 +540,7 @@ mod tests {
508
540
add_cost_models ( & pgpool, vec ! [ global_model. clone( ) ] ) . await ;
509
541
add_cost_models ( & pgpool, to_db_models ( test_models. clone ( ) ) ) . await ;
510
542
511
- let check = MinimumValue :: new ( pgpool) . await ;
543
+ let check = MinimumValue :: new ( pgpool, Duration :: from_secs ( 0 ) ) . await ;
512
544
513
545
let deployment_id = test_models[ 0 ] . deployment ;
514
546
let mut ctx = Context :: new ( ) ;
0 commit comments